"""Image-prediction pipeline assembled from small, composable stream stages."""
import os
from functools import lru_cache, partial
from itertools import chain, tee

from funcy import rcompose, first, compose, second, chunks, identity, rpartial
from tqdm import tqdm

from image_prediction.config import CONFIG
from image_prediction.default_objects import (
    get_formatter,
    get_mlflow_model_loader,
    get_image_classifier,
    get_extractor,
    get_encoder,
)
from image_prediction.locations import MLRUNS_DIR
from image_prediction.utils.generic import lift, starlift

os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"


@lru_cache(maxsize=None)
def load_pipeline(**kwargs):
    """Build a ``Pipeline`` for the configured MLflow run, memoised per kwargs.

    The model loader is created from ``MLRUNS_DIR`` and the model identifier is
    read from ``CONFIG.service.mlflow_run_id``. All keyword arguments are
    forwarded to ``Pipeline``; because of ``lru_cache`` they must be hashable,
    and repeated calls with the same keyword arguments return the same object.
    """
    return Pipeline(
        get_mlflow_model_loader(MLRUNS_DIR),
        CONFIG.service.mlflow_run_id,
        **kwargs,
    )


def parallel(*fs):
    """Return a function applying the i-th of ``fs`` to its i-th argument.

    The returned function yields the results lazily; ``zip`` truncates to the
    shorter of ``fs`` and the supplied arguments.
    """

    def apply_each(*args):
        return (fn(arg) for fn, arg in zip(fs, args))

    return apply_each


def star(f):
    """Return a one-argument function that unpacks its argument into ``f``."""

    def unpacked(x):
        return f(*x)

    return unpacked


class Pipeline:
    """Callable that turns a PDF into a stream of formatted image predictions.

    Stage layout (data flows left to right):

                             /--> classify ---\
        extract --> split --+---> represent ---+--> join --> reformat
                             \--> identity ---/
    """

    def __init__(self, model_loader, model_identifier, batch_size=16, verbose=False, **kwargs):
        """Wire the stages together.

        Args:
            model_loader: Passed to ``get_image_classifier``.
            model_identifier: Passed to ``get_image_classifier``.
            batch_size: Number of images handed to the classifier per call.
            verbose: When true, ``__call__`` shows a progress bar.
            **kwargs: Forwarded to ``get_extractor``.
        """
        self.verbose = verbose

        extractor = get_extractor(**kwargs)
        classifier = get_image_classifier(model_loader, model_identifier)
        formatter = get_formatter()
        encoder = get_encoder()

        def merge(prediction, representation, metadata):
            # Key order is deliberate: metadata entries may override
            # "classification", and "representation" is always written last.
            return {"classification": prediction, **metadata, "representation": representation}

        # Triplicate the (image, metadata) stream, then project the copies to
        # images, images and metadata respectively.
        fan_out = rcompose(
            rpartial(tee, 3),
            star(parallel(lift(first), lift(first), lift(second))),
        )

        # Group images into batches, classify each batch, flatten the results.
        classify_batched = rcompose(
            partial(chunks, batch_size),
            lift(classifier),
            chain.from_iterable,
        )

        # Given one function per stream, build a stage that takes the streams
        # as a single iterable and applies the functions position by position.
        pairwise_apply = compose(star, parallel)

        # Zip the three streams back together and merge each triple into a dict.
        fan_in = rcompose(star(zip), starlift(merge))

        self.pipe = rcompose(
            extractor,  # image-metadata pairs as a stream
            fan_out,  # two image streams and one metadata stream
            pairwise_apply(classify_batched, encoder, identity),  # one function per stream
            fan_in,  # recombine the streams item by item
            formatter,  # final formatting of the items
        )

    def __call__(self, pdf: bytes, page_range: range = None):
        """Yield the pipeline's output items for ``pdf``.

        Args:
            pdf: Document content, passed positionally to the pipeline.
            page_range: Passed to the pipeline as the ``page_range`` keyword.

        Yields:
            The items produced by the final (formatter) stage. Iteration is
            wrapped in a ``tqdm`` progress bar that is disabled unless the
            pipeline was created with ``verbose=True``.
        """
        results = self.pipe(pdf, page_range=page_range)
        progress = tqdm(
            results,
            desc="Processing images from document",
            unit=" images",
            disable=not self.verbose,
        )
        yield from progress