Refactor load_pipeline() so that the model is loaded only once, when the program starts

This commit is contained in:
Francisco Schulz 2023-02-01 14:49:32 +01:00
parent 8cce535301
commit a50ade4771
2 changed files with 20 additions and 13 deletions

View File

@ -4,12 +4,16 @@ import os
from glob import glob from glob import glob
from operator import truth from operator import truth
from image_prediction.pipeline import load_pipeline from image_prediction.config import CONFIG
from image_prediction.locations import MLRUNS_DIR
from image_prediction.pipeline import load_model, load_pipeline
from image_prediction.utils import get_logger from image_prediction.utils import get_logger
from image_prediction.utils.pdf_annotation import annotate_pdf from image_prediction.utils.pdf_annotation import annotate_pdf
logger = get_logger() logger = get_logger()
MODEL = load_model(MLRUNS_DIR, CONFIG)
def parse_args(): def parse_args():
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
@ -34,14 +38,16 @@ def process_pdf(pipeline, pdf_path, metadata=None, page_range=None):
predictions = list(pipeline(f.read(), page_range=page_range, metadata_per_image=metadata)) predictions = list(pipeline(f.read(), page_range=page_range, metadata_per_image=metadata))
annotate_pdf( annotate_pdf(
pdf_path, predictions, os.path.join("/tmp", os.path.basename(pdf_path.replace(".pdf", f"_{truth(metadata)}_annotated.pdf"))) pdf_path,
predictions,
os.path.join("/tmp", os.path.basename(pdf_path.replace(".pdf", f"_{truth(metadata)}_annotated.pdf"))),
) )
return predictions return predictions
def main(args): def main(args):
pipeline = load_pipeline(verbose=True, tolerance=3) pipeline = load_pipeline(model=MODEL, verbose=True, tolerance=3)
if os.path.isfile(args.input): if os.path.isfile(args.input):
pdf_paths = [args.input] pdf_paths = [args.input]

View File

@ -3,17 +3,21 @@ import io
import json import json
import logging import logging
from image_prediction.config import Config
from image_prediction.locations import CONFIG_FILE
from image_prediction.pipeline import load_pipeline
from image_prediction.utils.banner import load_banner
from image_prediction.utils.process_wrapping import wrap_in_process
from pyinfra import config from pyinfra import config
from pyinfra.queue.queue_manager import QueueManager from pyinfra.queue.queue_manager import QueueManager
from pyinfra.storage.storage import get_storage from pyinfra.storage.storage import get_storage
from image_prediction.config import CONFIG, Config
from image_prediction.locations import CONFIG_FILE, MLRUNS_DIR
from image_prediction.pipeline import load_model, load_pipeline
from image_prediction.utils.banner import load_banner
from image_prediction.utils.process_wrapping import wrap_in_process
PYINFRA_CONFIG = config.get_config() PYINFRA_CONFIG = config.get_config()
IMAGE_CONFIG = Config(CONFIG_FILE) IMAGE_CONFIG = Config(CONFIG_FILE)
MODEL = load_model(MLRUNS_DIR, CONFIG)
BUCKET = PYINFRA_CONFIG.storage_bucket
STORAGE = get_storage(PYINFRA_CONFIG)
logging.getLogger().addHandler(logging.StreamHandler()) logging.getLogger().addHandler(logging.StreamHandler())
logger = logging.getLogger("main") logger = logging.getLogger("main")
@ -26,17 +30,14 @@ logger.setLevel(PYINFRA_CONFIG.logging_level_root)
# FIXME: Find more fine-grained solution or if the problem occurs persistently for python services, # FIXME: Find more fine-grained solution or if the problem occurs persistently for python services,
# FIXME: move the process wrapper to a general module (see RED-4929). # FIXME: move the process wrapper to a general module (see RED-4929).
@wrap_in_process @wrap_in_process
def process_request(request_message): def process_request(request_message, bucket=BUCKET, storage=STORAGE, model=MODEL, img_config=IMAGE_CONFIG):
dossier_id = request_message["dossierId"] dossier_id = request_message["dossierId"]
file_id = request_message["fileId"] file_id = request_message["fileId"]
target_file_name = f"{dossier_id}/{file_id}.{request_message['targetFileExtension']}" target_file_name = f"{dossier_id}/{file_id}.{request_message['targetFileExtension']}"
response_file_name = f"{dossier_id}/{file_id}.{request_message['responseFileExtension']}" response_file_name = f"{dossier_id}/{file_id}.{request_message['responseFileExtension']}"
figure_data_file_name = f"{dossier_id}/{file_id}.FIGURE.json.gz" figure_data_file_name = f"{dossier_id}/{file_id}.FIGURE.json.gz"
bucket = PYINFRA_CONFIG.storage_bucket pipeline = load_pipeline(model=model, verbose=img_config.service.verbose, batch_size=img_config.service.batch_size)
storage = get_storage(PYINFRA_CONFIG)
pipeline = load_pipeline(verbose=IMAGE_CONFIG.service.verbose, batch_size=IMAGE_CONFIG.service.batch_size)
if storage.exists(bucket, target_file_name): if storage.exists(bucket, target_file_name):
should_publish_result = True should_publish_result = True