"""Document image-classification pipeline.

Extracts (image, metadata) pairs from a PDF, classifies the images in batches,
re-attaches each prediction to its metadata and formats the result.
"""
import os

# Set this before the remaining imports: TensorFlow reads TF_CPP_MIN_LOG_LEVEL
# when it is first imported, so assigning it afterwards has no effect if any of
# the imports below load TensorFlow (presumably the model loader / classifier
# factories do — unverified from this file).
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"

from functools import partial  # noqa: E402
from itertools import chain, tee  # noqa: E402
from typing import Optional  # noqa: E402

from funcy import rcompose, juxt, first, compose, second, chunks, identity  # noqa: E402

from image_prediction.config import CONFIG  # noqa: E402
from image_prediction.default_objects import (  # noqa: E402
    get_formatter,
    get_mlflow_model_loader,
    get_image_classifier,
    get_extractor,
)
from image_prediction.locations import MLRUNS_DIR  # noqa: E402
from image_prediction.utils.generic import lift, starlift  # noqa: E402


def load_pipeline(**kwargs):
    """Build a ``Pipeline`` backed by the MLflow model ``CONFIG.service.run_id``.

    Args:
        **kwargs: Forwarded to ``Pipeline`` (e.g. ``batch_size``). Keys that
            ``Pipeline`` does not consume itself are passed on to
            ``get_extractor``.

    Returns:
        A ready-to-call ``Pipeline`` instance.
    """
    model_loader = get_mlflow_model_loader(MLRUNS_DIR)
    model_identifier = CONFIG.service.run_id
    return Pipeline(
        model_loader,
        model_identifier,
        progress_message="Processing document",
        **kwargs,
    )


def parallel(*fs):
    """Return a function that applies the i-th of ``fs`` to its i-th positional argument.

    The returned function yields its results lazily as a generator. Surplus
    functions or arguments are ignored, following ``zip`` semantics.
    """
    return lambda *args: (f(a) for f, a in zip(fs, args))


def splat(f):
    """Adapt ``f(*args)`` into a one-argument function taking the args as a single iterable."""
    return lambda x: f(*x)


def inspect(x):
    """Debugging aid: materialize ``x``, open an interactive IPython shell, then return the list.

    Insert into a pipeline (e.g. inside ``rcompose``) to examine intermediate
    values. It requires IPython and blocks until the shell is exited.
    """
    x = list(x)
    import IPython

    IPython.embed()
    return x


class Pipeline:
    """Callable PDF -> formatted-predictions pipeline.

    Args:
        model_loader: Loader passed to ``get_image_classifier``.
        model_identifier: Identifier of the model to load via ``model_loader``.
        batch_size: Number of images handed to the classifier per call.
        **kwargs: Forwarded to ``get_extractor``.
    """

    def __init__(self, model_loader, model_identifier, batch_size=16, **kwargs):
        extractor = get_extractor(**kwargs)
        classifier = get_image_classifier(model_loader, model_identifier)
        formatter = get_formatter()

        # Group images into batches, classify each batch, then flatten the
        # per-batch predictions back into one stream aligned with the images.
        batcher = partial(chunks, batch_size)
        classify = compose(chain.from_iterable, lift(classifier), batcher)

        def join_prediction_and_metadata(prd, mdt):
            """Merge one prediction into a copy of its metadata mapping."""
            return {"classification": prd, **mdt}

        # Data flow:
        #
        #   extract (image, metadata) pairs
        #     --> split: images ---classify---+
        #                metadata --identity--+--> zip --> join pairs --> format --> return
        #
        # ``tee`` duplicates the pair stream so that images and metadata can be
        # mapped independently and re-aligned positionally by ``zip``.
        self.pipe = rcompose(
            extractor,
            tee,
            splat(parallel(*map(lift, (first, second)))),
            splat(parallel(classify, identity)),
            splat(zip),
            starlift(join_prediction_and_metadata),
            formatter,
        )

    def __call__(self, pdf: bytes, page_range: Optional[range] = None):
        """Run the pipeline on ``pdf`` and return all formatted results as a list.

        Args:
            pdf: Raw document bytes, passed to the extractor.
            page_range: Passed to the extractor as ``page_range``. ``None`` is
                the default; its meaning is defined by the extractor.
        """
        return list(self.pipe(pdf, page_range=page_range))