diff --git a/image_prediction/pipeline.py b/image_prediction/pipeline.py index 7392d72..0be6b55 100644 --- a/image_prediction/pipeline.py +++ b/image_prediction/pipeline.py @@ -1,5 +1,6 @@ import os -from itertools import chain +from functools import partial +from itertools import chain, starmap from funcy import rcompose, juxt, first, compose, second, chunks, curry @@ -23,14 +24,21 @@ def load_pipeline(**kwargs): class Pipeline: def __init__(self, model_loader, model_identifier, batch_size=16, **kwargs): extractor = get_extractor(**kwargs) - batcher = lambda x: chunks(batch_size, x) + batcher = compose(lift(list), partial(chunks, batch_size)) classifier = get_image_classifier(model_loader, model_identifier) - merger = lambda predictions, metadata: ({"classification": prd, **mdt} for prd, mdt in zip(predictions, metadata)) - formatter = get_formatter() left = compose(classifier, lift(first)) right = lift(second) + formatter = get_formatter() + + def join_prediction_and_metadata(prd, mdt): + return {"classification": prd, **mdt} + + def process_batch(batch): + classifications, metadata = juxt(left, right)(batch) + return starmap(join_prediction_and_metadata, zip(classifications, metadata)) + # -------- # -- -- -- -- # == == == == @@ -39,6 +47,7 @@ class Pipeline: # -------- def inspect(x): + x = list(x) import IPython IPython.embed() return x @@ -46,9 +55,7 @@ class Pipeline: self.pipe = rcompose( extractor, batcher, - lift(list), - lift(juxt(left, right)), - starlift(merger), + lift(process_batch), chain.from_iterable, formatter, )