"""Image-prediction pipeline.

Extracts images from a PDF, classifies and encodes them, and re-joins the
results with each image's metadata, exposing the whole thing as a lazy
stream (see the ASCII dataflow diagram in Pipeline.__init__).
"""
import os

# NOTE: must run BEFORE TensorFlow is (transitively) imported by the
# image_prediction modules below — the C++ log level is read once at TF
# import time, so setting it afterwards has no effect.
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"

from functools import partial
from itertools import chain, tee
from typing import Iterable

from funcy import rcompose, first, compose, second, chunks, identity, rpartial
from tqdm import tqdm

from image_prediction.config import CONFIG
from image_prediction.default_objects import (
    get_formatter,
    get_mlflow_model_loader,
    get_image_classifier,
    get_extractor,
    get_encoder,
)
from image_prediction.locations import MLRUNS_DIR
from image_prediction.utils.generic import lift, starlift


def load_pipeline(**kwargs):
    """Build a Pipeline for the MLflow run configured in CONFIG.

    Any keyword arguments are forwarded to Pipeline.__init__ (and from
    there to the image extractor).
    """
    model_loader = get_mlflow_model_loader(MLRUNS_DIR)
    model_identifier = CONFIG.service.mlflow_run_id
    return Pipeline(model_loader, model_identifier, **kwargs)


def parallel(*fs):
    """Apply each function to the positional argument at the same index.

    parallel(f, g)(a, b) -> lazy generator yielding f(a), g(b).
    """
    return lambda *args: (f(a) for f, a in zip(fs, args))


def star(f):
    """Adapt f(*x) so it can be called with a single iterable argument."""
    return lambda x: f(*x)


class Pipeline:
    """Lazy image-classification pipeline over the pages of a PDF."""

    def __init__(self, model_loader, model_identifier, batch_size=16, verbose=True, **kwargs):
        """Wire up the extract -> (classify | encode | metadata) -> join graph.

        model_loader      -- callable that loads the classifier model
        model_identifier  -- MLflow run id of the model to load
        batch_size        -- number of images per classifier batch
        verbose           -- show a tqdm progress bar during __call__
        kwargs            -- forwarded to the image extractor
        """
        self.verbose = verbose
        extract = get_extractor(**kwargs)
        classifier = get_image_classifier(model_loader, model_identifier)
        reformat = get_formatter()
        represent = get_encoder()

        # Fan the (image, metadata) stream out into three copies and project
        # each: image for classification, image for encoding, metadata as-is.
        split = compose(
            star(parallel(*map(lift, (first, first, second)))),
            rpartial(tee, 3),
        )
        # Classify in batches of `batch_size`, then flatten back to a stream.
        classify = compose(
            chain.from_iterable,
            lift(classifier),
            partial(chunks, batch_size),
        )
        pairwise_apply = compose(star, parallel)
        # Zip the three result streams and merge each triple into one dict.
        join = compose(
            starlift(lambda prd, rpr, mdt: {"classification": prd, **mdt, "representation": rpr}),
            star(zip),
        )

        #                   />--classify--\
        # --extract-->--split--+->--encode---->+--join-->reformat
        #                   \>--identity--/
        self.pipe = rcompose(
            extract,                                   # image-metadata pairs as a stream
            split,                                     # into an image stream and a metadata stream
            pairwise_apply(classify, represent, identity),  # apply functions to the streams pairwise
            join,                                      # the streams by zipping
            reformat,                                  # the items
        )

    def __call__(self, pdf: bytes, page_range: range = None):
        """Yield one result dict per image extracted from `pdf`.

        pdf        -- raw PDF bytes
        page_range -- optional range of pages to process (None = all)
        """
        yield from tqdm(
            self.pipe(pdf, page_range=page_range),
            desc="Processing images from document",
            unit=" images",
            disable=not self.verbose,
        )