diff --git a/image_prediction/pipeline.py b/image_prediction/pipeline.py index 0be6b55..7bf84b5 100644 --- a/image_prediction/pipeline.py +++ b/image_prediction/pipeline.py @@ -35,9 +35,6 @@ class Pipeline: def join_prediction_and_metadata(prd, mdt): return {"classification": prd, **mdt} - def process_batch(batch): - classifications, metadata = juxt(left, right)(batch) - return starmap(join_prediction_and_metadata, zip(classifications, metadata)) # -------- # -- -- -- -- @@ -55,7 +52,9 @@ class Pipeline: self.pipe = rcompose( extractor, batcher, - lift(process_batch), + lift(juxt(left, right)), + starlift(zip), + lift(starlift(join_prediction_and_metadata)), chain.from_iterable, formatter, )