refactoring

This commit is contained in:
Matthias Bisping 2022-04-21 19:43:39 +02:00
parent f036ee55e6
commit 72e785e3e3

View File

@ -1,5 +1,6 @@
import os
from itertools import chain
from functools import partial
from itertools import chain, starmap
from funcy import rcompose, juxt, first, compose, second, chunks, curry
@ -23,14 +24,21 @@ def load_pipeline(**kwargs):
class Pipeline:
def __init__(self, model_loader, model_identifier, batch_size=16, **kwargs):
extractor = get_extractor(**kwargs)
batcher = lambda x: chunks(batch_size, x)
batcher = compose(lift(list), partial(chunks, batch_size))
classifier = get_image_classifier(model_loader, model_identifier)
merger = lambda predictions, metadata: ({"classification": prd, **mdt} for prd, mdt in zip(predictions, metadata))
formatter = get_formatter()
left = compose(classifier, lift(first))
right = lift(second)
formatter = get_formatter()
def join_prediction_and_metadata(prd, mdt):
return {"classification": prd, **mdt}
def process_batch(batch):
classifications, metadata = juxt(left, right)(batch)
return starmap(join_prediction_and_metadata, zip(classifications, metadata))
# --------
# -- -- -- --
# == == == ==
@ -39,6 +47,7 @@ class Pipeline:
# --------
def inspect(x):
x = list(x)
import IPython
IPython.embed()
return x
@ -46,9 +55,7 @@ class Pipeline:
self.pipe = rcompose(
extractor,
batcher,
lift(list),
lift(juxt(left, right)),
starlift(merger),
lift(process_batch),
chain.from_iterable,
formatter,
)