refactor load_pipeline() so the model is only loaded once when the program starts
This commit is contained in:
parent
8cce535301
commit
a50ade4771
@ -4,12 +4,16 @@ import os
|
||||
from glob import glob
|
||||
from operator import truth
|
||||
|
||||
from image_prediction.pipeline import load_pipeline
|
||||
from image_prediction.config import CONFIG
|
||||
from image_prediction.locations import MLRUNS_DIR
|
||||
from image_prediction.pipeline import load_model, load_pipeline
|
||||
from image_prediction.utils import get_logger
|
||||
from image_prediction.utils.pdf_annotation import annotate_pdf
|
||||
|
||||
logger = get_logger()
|
||||
|
||||
MODEL = load_model(MLRUNS_DIR, CONFIG)
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser()
|
||||
@ -34,14 +38,16 @@ def process_pdf(pipeline, pdf_path, metadata=None, page_range=None):
|
||||
predictions = list(pipeline(f.read(), page_range=page_range, metadata_per_image=metadata))
|
||||
|
||||
annotate_pdf(
|
||||
pdf_path, predictions, os.path.join("/tmp", os.path.basename(pdf_path.replace(".pdf", f"_{truth(metadata)}_annotated.pdf")))
|
||||
pdf_path,
|
||||
predictions,
|
||||
os.path.join("/tmp", os.path.basename(pdf_path.replace(".pdf", f"_{truth(metadata)}_annotated.pdf"))),
|
||||
)
|
||||
|
||||
return predictions
|
||||
|
||||
|
||||
def main(args):
|
||||
pipeline = load_pipeline(verbose=True, tolerance=3)
|
||||
pipeline = load_pipeline(model=MODEL, verbose=True, tolerance=3)
|
||||
|
||||
if os.path.isfile(args.input):
|
||||
pdf_paths = [args.input]
|
||||
|
||||
21
src/serve.py
21
src/serve.py
@ -3,17 +3,21 @@ import io
|
||||
import json
|
||||
import logging
|
||||
|
||||
from image_prediction.config import Config
|
||||
from image_prediction.locations import CONFIG_FILE
|
||||
from image_prediction.pipeline import load_pipeline
|
||||
from image_prediction.utils.banner import load_banner
|
||||
from image_prediction.utils.process_wrapping import wrap_in_process
|
||||
from pyinfra import config
|
||||
from pyinfra.queue.queue_manager import QueueManager
|
||||
from pyinfra.storage.storage import get_storage
|
||||
|
||||
from image_prediction.config import CONFIG, Config
|
||||
from image_prediction.locations import CONFIG_FILE, MLRUNS_DIR
|
||||
from image_prediction.pipeline import load_model, load_pipeline
|
||||
from image_prediction.utils.banner import load_banner
|
||||
from image_prediction.utils.process_wrapping import wrap_in_process
|
||||
|
||||
PYINFRA_CONFIG = config.get_config()
|
||||
IMAGE_CONFIG = Config(CONFIG_FILE)
|
||||
MODEL = load_model(MLRUNS_DIR, CONFIG)
|
||||
BUCKET = PYINFRA_CONFIG.storage_bucket
|
||||
STORAGE = get_storage(PYINFRA_CONFIG)
|
||||
|
||||
logging.getLogger().addHandler(logging.StreamHandler())
|
||||
logger = logging.getLogger("main")
|
||||
@ -26,17 +30,14 @@ logger.setLevel(PYINFRA_CONFIG.logging_level_root)
|
||||
# FIXME: Find more fine-grained solution or if the problem occurs persistently for python services,
|
||||
# FIXME: move the process wrapper to a general module (see RED-4929).
|
||||
@wrap_in_process
|
||||
def process_request(request_message):
|
||||
def process_request(request_message, bucket=BUCKET, storage=STORAGE, model=MODEL, img_config=IMAGE_CONFIG):
|
||||
dossier_id = request_message["dossierId"]
|
||||
file_id = request_message["fileId"]
|
||||
target_file_name = f"{dossier_id}/{file_id}.{request_message['targetFileExtension']}"
|
||||
response_file_name = f"{dossier_id}/{file_id}.{request_message['responseFileExtension']}"
|
||||
figure_data_file_name = f"{dossier_id}/{file_id}.FIGURE.json.gz"
|
||||
|
||||
bucket = PYINFRA_CONFIG.storage_bucket
|
||||
storage = get_storage(PYINFRA_CONFIG)
|
||||
|
||||
pipeline = load_pipeline(verbose=IMAGE_CONFIG.service.verbose, batch_size=IMAGE_CONFIG.service.batch_size)
|
||||
pipeline = load_pipeline(model=model, verbose=img_config.service.verbose, batch_size=img_config.service.batch_size)
|
||||
|
||||
if storage.exists(bucket, target_file_name):
|
||||
should_publish_result = True
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user