refactor load_pipeline() so the model is only loaded once when the program starts
This commit is contained in:
parent
8cce535301
commit
a50ade4771
@ -4,12 +4,16 @@ import os
|
|||||||
from glob import glob
|
from glob import glob
|
||||||
from operator import truth
|
from operator import truth
|
||||||
|
|
||||||
from image_prediction.pipeline import load_pipeline
|
from image_prediction.config import CONFIG
|
||||||
|
from image_prediction.locations import MLRUNS_DIR
|
||||||
|
from image_prediction.pipeline import load_model, load_pipeline
|
||||||
from image_prediction.utils import get_logger
|
from image_prediction.utils import get_logger
|
||||||
from image_prediction.utils.pdf_annotation import annotate_pdf
|
from image_prediction.utils.pdf_annotation import annotate_pdf
|
||||||
|
|
||||||
logger = get_logger()
|
logger = get_logger()
|
||||||
|
|
||||||
|
MODEL = load_model(MLRUNS_DIR, CONFIG)
|
||||||
|
|
||||||
|
|
||||||
def parse_args():
|
def parse_args():
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
@ -34,14 +38,16 @@ def process_pdf(pipeline, pdf_path, metadata=None, page_range=None):
|
|||||||
predictions = list(pipeline(f.read(), page_range=page_range, metadata_per_image=metadata))
|
predictions = list(pipeline(f.read(), page_range=page_range, metadata_per_image=metadata))
|
||||||
|
|
||||||
annotate_pdf(
|
annotate_pdf(
|
||||||
pdf_path, predictions, os.path.join("/tmp", os.path.basename(pdf_path.replace(".pdf", f"_{truth(metadata)}_annotated.pdf")))
|
pdf_path,
|
||||||
|
predictions,
|
||||||
|
os.path.join("/tmp", os.path.basename(pdf_path.replace(".pdf", f"_{truth(metadata)}_annotated.pdf"))),
|
||||||
)
|
)
|
||||||
|
|
||||||
return predictions
|
return predictions
|
||||||
|
|
||||||
|
|
||||||
def main(args):
|
def main(args):
|
||||||
pipeline = load_pipeline(verbose=True, tolerance=3)
|
pipeline = load_pipeline(model=MODEL, verbose=True, tolerance=3)
|
||||||
|
|
||||||
if os.path.isfile(args.input):
|
if os.path.isfile(args.input):
|
||||||
pdf_paths = [args.input]
|
pdf_paths = [args.input]
|
||||||
|
|||||||
21
src/serve.py
21
src/serve.py
@ -3,17 +3,21 @@ import io
|
|||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
from image_prediction.config import Config
|
|
||||||
from image_prediction.locations import CONFIG_FILE
|
|
||||||
from image_prediction.pipeline import load_pipeline
|
|
||||||
from image_prediction.utils.banner import load_banner
|
|
||||||
from image_prediction.utils.process_wrapping import wrap_in_process
|
|
||||||
from pyinfra import config
|
from pyinfra import config
|
||||||
from pyinfra.queue.queue_manager import QueueManager
|
from pyinfra.queue.queue_manager import QueueManager
|
||||||
from pyinfra.storage.storage import get_storage
|
from pyinfra.storage.storage import get_storage
|
||||||
|
|
||||||
|
from image_prediction.config import CONFIG, Config
|
||||||
|
from image_prediction.locations import CONFIG_FILE, MLRUNS_DIR
|
||||||
|
from image_prediction.pipeline import load_model, load_pipeline
|
||||||
|
from image_prediction.utils.banner import load_banner
|
||||||
|
from image_prediction.utils.process_wrapping import wrap_in_process
|
||||||
|
|
||||||
PYINFRA_CONFIG = config.get_config()
|
PYINFRA_CONFIG = config.get_config()
|
||||||
IMAGE_CONFIG = Config(CONFIG_FILE)
|
IMAGE_CONFIG = Config(CONFIG_FILE)
|
||||||
|
MODEL = load_model(MLRUNS_DIR, CONFIG)
|
||||||
|
BUCKET = PYINFRA_CONFIG.storage_bucket
|
||||||
|
STORAGE = get_storage(PYINFRA_CONFIG)
|
||||||
|
|
||||||
logging.getLogger().addHandler(logging.StreamHandler())
|
logging.getLogger().addHandler(logging.StreamHandler())
|
||||||
logger = logging.getLogger("main")
|
logger = logging.getLogger("main")
|
||||||
@ -26,17 +30,14 @@ logger.setLevel(PYINFRA_CONFIG.logging_level_root)
|
|||||||
# FIXME: Find more fine-grained solution or if the problem occurs persistently for python services,
|
# FIXME: Find more fine-grained solution or if the problem occurs persistently for python services,
|
||||||
# FIXME: move the process wrapper to a general module (see RED-4929).
|
# FIXME: move the process wrapper to a general module (see RED-4929).
|
||||||
@wrap_in_process
|
@wrap_in_process
|
||||||
def process_request(request_message):
|
def process_request(request_message, bucket=BUCKET, storage=STORAGE, model=MODEL, img_config=IMAGE_CONFIG):
|
||||||
dossier_id = request_message["dossierId"]
|
dossier_id = request_message["dossierId"]
|
||||||
file_id = request_message["fileId"]
|
file_id = request_message["fileId"]
|
||||||
target_file_name = f"{dossier_id}/{file_id}.{request_message['targetFileExtension']}"
|
target_file_name = f"{dossier_id}/{file_id}.{request_message['targetFileExtension']}"
|
||||||
response_file_name = f"{dossier_id}/{file_id}.{request_message['responseFileExtension']}"
|
response_file_name = f"{dossier_id}/{file_id}.{request_message['responseFileExtension']}"
|
||||||
figure_data_file_name = f"{dossier_id}/{file_id}.FIGURE.json.gz"
|
figure_data_file_name = f"{dossier_id}/{file_id}.FIGURE.json.gz"
|
||||||
|
|
||||||
bucket = PYINFRA_CONFIG.storage_bucket
|
pipeline = load_pipeline(model=model, verbose=img_config.service.verbose, batch_size=img_config.service.batch_size)
|
||||||
storage = get_storage(PYINFRA_CONFIG)
|
|
||||||
|
|
||||||
pipeline = load_pipeline(verbose=IMAGE_CONFIG.service.verbose, batch_size=IMAGE_CONFIG.service.batch_size)
|
|
||||||
|
|
||||||
if storage.exists(bucket, target_file_name):
|
if storage.exists(bucket, target_file_name):
|
||||||
should_publish_result = True
|
should_publish_result = True
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user