import os
from itertools import chain
from typing import Optional

from funcy import rcompose, juxt, first, compose, second, chunks, curry

from image_prediction.config import CONFIG
from image_prediction.default_objects import get_formatter, get_mlflow_model_loader, get_image_classifier, get_extractor
from image_prediction.locations import MLRUNS_DIR
from image_prediction.utils.generic import lift, starlift

# NOTE(review): presumably meant to silence TensorFlow's C++ logging; it is set
# after the imports above, exactly as in the original statement order.
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"


def load_pipeline(**kwargs):
    """Build a ``Pipeline`` for the model run configured in ``CONFIG``.

    The model loader is created from ``MLRUNS_DIR`` and the model identifier is
    read from ``CONFIG.service.run_id``.

    Args:
        **kwargs: Forwarded to ``Pipeline`` together with
            ``progress_message="Processing document"``.

    Returns:
        The constructed ``Pipeline`` instance.
    """
    model_loader = get_mlflow_model_loader(MLRUNS_DIR)
    model_identifier = CONFIG.service.run_id
    return Pipeline(model_loader, model_identifier, progress_message="Processing document", **kwargs)


class Pipeline:
    """Composed extract -> batch -> classify -> merge -> format pipeline.

    Instances are callable; calling one yields whatever the composed pipeline
    produces for the given document.
    """

    def __init__(self, model_loader, model_identifier, batch_size=16, **kwargs):
        """Wire the individual stages into ``self.pipe``.

        Args:
            model_loader: Passed to ``get_image_classifier``.
            model_identifier: Passed to ``get_image_classifier``.
            batch_size: Chunk size used to batch the extractor output.
            **kwargs: Forwarded to ``get_extractor``.
        """
        extractor = get_extractor(**kwargs)
        classifier = get_image_classifier(model_loader, model_identifier)
        formatter = get_formatter()

        def batcher(items):
            # Split the extractor output into chunks of ``batch_size`` items.
            return chunks(batch_size, items)

        def merger(predictions, metadata):
            # Pair every prediction with its metadata dict and fold the
            # prediction in under the "classification" key.
            return ({"classification": prd, **mdt} for prd, mdt in zip(predictions, metadata))

        # Each batch is split into two branches that are re-joined by ``merger``:
        #   left  - classifier applied to the first element of every item
        #   right - the second element of every item, passed through unchanged
        # NOTE(review): presumably the extractor yields (image, metadata) pairs,
        # ``lift`` maps a function over an iterable and ``starlift`` is its
        # argument-unpacking counterpart - confirm in
        # image_prediction.utils.generic.
        left = compose(classifier, lift(first))
        right = lift(second)

        self.pipe = rcompose(
            extractor,
            batcher,
            lift(list),
            lift(juxt(left, right)),
            starlift(merger),
            chain.from_iterable,
            formatter,
        )

    def __call__(self, pdf: bytes, page_range: Optional[range] = None):
        """Run the pipeline on ``pdf`` and yield its results.

        Args:
            pdf: Document content handed to the first pipeline stage.
            page_range: Forwarded to the first stage as the ``page_range``
                keyword argument; ``None`` by default.

        Yields:
            The items produced by the composed pipeline.
        """
        yield from self.pipe(pdf, page_range=page_range)