diff --git a/image_prediction/pipeline.py b/image_prediction/pipeline.py index 7bf84b5..d4d5024 100644 --- a/image_prediction/pipeline.py +++ b/image_prediction/pipeline.py @@ -1,8 +1,8 @@ import os from functools import partial -from itertools import chain, starmap +from itertools import chain, tee -from funcy import rcompose, juxt, first, compose, second, chunks, curry +from funcy import rcompose, juxt, first, compose, second, chunks, identity from image_prediction.config import CONFIG from image_prediction.default_objects import get_formatter, get_mlflow_model_loader, get_image_classifier, get_extractor @@ -21,43 +21,54 @@ def load_pipeline(**kwargs): return pipeline +def parallel(*fs): + return lambda *args: (f(a) for f, a in zip(fs, args)) + + +def splat(f): + return lambda x: f(*x) + + +def inspect(x): + x = list(x) + import IPython + + IPython.embed() + return x + + class Pipeline: def __init__(self, model_loader, model_identifier, batch_size=16, **kwargs): extractor = get_extractor(**kwargs) - batcher = compose(lift(list), partial(chunks, batch_size)) classifier = get_image_classifier(model_loader, model_identifier) - - left = compose(classifier, lift(first)) - right = lift(second) - formatter = get_formatter() + batcher = partial(chunks, batch_size) + + classify = compose(chain.from_iterable, lift(classifier), batcher) + + def join_prediction_and_metadata(prd, mdt): return {"classification": prd, **mdt} - - # -------- - # -- -- -- -- - # == == == == - # -- -- -- -- - # -------- - # -------- - - def inspect(x): - x = list(x) - import IPython - IPython.embed() - return x + # +--classify--+ + # --extract image metadata pairs-->--split--| |--zip-->-join-pairs-->format-->return + # +--identity--+ self.pipe = rcompose( extractor, - batcher, - lift(juxt(left, right)), - starlift(zip), - lift(starlift(join_prediction_and_metadata)), - chain.from_iterable, + tee, + splat(parallel(*map(lift, (first, second)))), + splat(parallel(classify, identity)), + splat(zip), + 
starlift(join_prediction_and_metadata), formatter, ) def __call__(self, pdf: bytes, page_range: range = None): - yield from self.pipe(pdf, page_range=page_range) + r = self.pipe(pdf, page_range=page_range) + + r = list(r) + + + return r