diff --git a/scripts/run_pipeline.py b/scripts/run_pipeline.py index 29d3199..eb588d6 100644 --- a/scripts/run_pipeline.py +++ b/scripts/run_pipeline.py @@ -4,12 +4,16 @@ import os from glob import glob from operator import truth -from image_prediction.pipeline import load_pipeline +from image_prediction.config import CONFIG +from image_prediction.locations import MLRUNS_DIR +from image_prediction.pipeline import load_model, load_pipeline from image_prediction.utils import get_logger from image_prediction.utils.pdf_annotation import annotate_pdf logger = get_logger() +MODEL = load_model(MLRUNS_DIR, CONFIG) + def parse_args(): parser = argparse.ArgumentParser() @@ -34,14 +38,16 @@ def process_pdf(pipeline, pdf_path, metadata=None, page_range=None): predictions = list(pipeline(f.read(), page_range=page_range, metadata_per_image=metadata)) annotate_pdf( - pdf_path, predictions, os.path.join("/tmp", os.path.basename(pdf_path.replace(".pdf", f"_{truth(metadata)}_annotated.pdf"))) + pdf_path, + predictions, + os.path.join("/tmp", os.path.basename(pdf_path.replace(".pdf", f"_{truth(metadata)}_annotated.pdf"))), ) return predictions def main(args): - pipeline = load_pipeline(verbose=True, tolerance=3) + pipeline = load_pipeline(model=MODEL, verbose=True, tolerance=3) if os.path.isfile(args.input): pdf_paths = [args.input] diff --git a/src/serve.py b/src/serve.py index ece6a0b..de94f03 100644 --- a/src/serve.py +++ b/src/serve.py @@ -3,17 +3,21 @@ import io import json import logging -from image_prediction.config import Config -from image_prediction.locations import CONFIG_FILE -from image_prediction.pipeline import load_pipeline -from image_prediction.utils.banner import load_banner -from image_prediction.utils.process_wrapping import wrap_in_process from pyinfra import config from pyinfra.queue.queue_manager import QueueManager from pyinfra.storage.storage import get_storage +from image_prediction.config import CONFIG, Config +from image_prediction.locations import CONFIG_FILE, MLRUNS_DIR +from image_prediction.pipeline import load_model, load_pipeline +from image_prediction.utils.banner import load_banner +from image_prediction.utils.process_wrapping import wrap_in_process + PYINFRA_CONFIG = config.get_config() IMAGE_CONFIG = Config(CONFIG_FILE) +MODEL = load_model(MLRUNS_DIR, CONFIG) +BUCKET = PYINFRA_CONFIG.storage_bucket +STORAGE = get_storage(PYINFRA_CONFIG) logging.getLogger().addHandler(logging.StreamHandler()) logger = logging.getLogger("main") @@ -26,17 +30,14 @@ logger.setLevel(PYINFRA_CONFIG.logging_level_root) # FIXME: Find more fine-grained solution or if the problem occurs persistently for python services, # FIXME: move the process wrapper to a general module (see RED-4929). @wrap_in_process -def process_request(request_message): +def process_request(request_message, bucket=BUCKET, storage=STORAGE, model=MODEL, img_config=IMAGE_CONFIG): dossier_id = request_message["dossierId"] file_id = request_message["fileId"] target_file_name = f"{dossier_id}/{file_id}.{request_message['targetFileExtension']}" response_file_name = f"{dossier_id}/{file_id}.{request_message['responseFileExtension']}" figure_data_file_name = f"{dossier_id}/{file_id}.FIGURE.json.gz" - bucket = PYINFRA_CONFIG.storage_bucket - storage = get_storage(PYINFRA_CONFIG) - - pipeline = load_pipeline(verbose=IMAGE_CONFIG.service.verbose, batch_size=IMAGE_CONFIG.service.batch_size) + pipeline = load_pipeline(model=model, verbose=img_config.service.verbose, batch_size=img_config.service.batch_size) if storage.exists(bucket, target_file_name): should_publish_result = True