From 50b161192db43a84464125c6d79650225e1010d6 Mon Sep 17 00:00:00 2001 From: Matthias Bisping Date: Thu, 21 Apr 2022 21:20:18 +0200 Subject: [PATCH] refactoring --- image_prediction/pipeline.py | 49 +++++++++++------------------------- 1 file changed, 15 insertions(+), 34 deletions(-) diff --git a/image_prediction/pipeline.py b/image_prediction/pipeline.py index d4d5024..9621ae9 100644 --- a/image_prediction/pipeline.py +++ b/image_prediction/pipeline.py @@ -2,7 +2,7 @@ import os from functools import partial from itertools import chain, tee -from funcy import rcompose, juxt, first, compose, second, chunks, identity +from funcy import rcompose, first, compose, second, chunks, identity from image_prediction.config import CONFIG from image_prediction.default_objects import get_formatter, get_mlflow_model_loader, get_image_classifier, get_extractor @@ -29,46 +29,27 @@ def splat(f): return lambda x: f(*x) -def inspect(x): - x = list(x) - import IPython - - IPython.embed() - return x - - class Pipeline: def __init__(self, model_loader, model_identifier, batch_size=16, **kwargs): - extractor = get_extractor(**kwargs) + extract = get_extractor(**kwargs) classifier = get_image_classifier(model_loader, model_identifier) - formatter = get_formatter() + reformat = get_formatter() - batcher = partial(chunks, batch_size) + split = compose(splat(parallel(*map(lift, (first, second)))), tee) + classify = compose(chain.from_iterable, lift(classifier), partial(chunks, batch_size)) + join = compose(starlift(lambda prd, mdt: {"classification": prd, **mdt}), splat(zip)) - classify = compose(chain.from_iterable, lift(classifier), batcher) - - - def join_prediction_and_metadata(prd, mdt): - return {"classification": prd, **mdt} - - # +--classify--+ - # --extract image metadata paris-->--split--| |--zip-->-join-pairs-->format-->return - # +--identity--+ + # +>--classify--v + # --extract image metadata pairs-->--split--| |--join-->format + # +>--identity--^ self.pipe = rcompose( - extractor, - tee, - splat(parallel(*map(lift, (first, second)))), - splat(parallel(classify, identity)), - splat(zip), - starlift(join_prediction_and_metadata), - formatter, + extract, # ... image-metadata-pairs as a stream + split, # ... into an image stream and a metadata stream + splat(parallel(classify, identity)), # ... process streams independently + join, # ... the streams + reformat, # ... the items ) def __call__(self, pdf: bytes, page_range: range = None): - r = self.pipe(pdf, page_range=page_range) - - r = list(r) - - - return r + yield from self.pipe(pdf, page_range=page_range)