refactoring
This commit is contained in:
parent
f036ee55e6
commit
72e785e3e3
@ -1,5 +1,6 @@
|
||||
import os
|
||||
from itertools import chain
|
||||
from functools import partial
|
||||
from itertools import chain, starmap
|
||||
|
||||
from funcy import rcompose, juxt, first, compose, second, chunks, curry
|
||||
|
||||
@ -23,14 +24,21 @@ def load_pipeline(**kwargs):
|
||||
class Pipeline:
|
||||
def __init__(self, model_loader, model_identifier, batch_size=16, **kwargs):
|
||||
extractor = get_extractor(**kwargs)
|
||||
batcher = lambda x: chunks(batch_size, x)
|
||||
batcher = compose(lift(list), partial(chunks, batch_size))
|
||||
classifier = get_image_classifier(model_loader, model_identifier)
|
||||
merger = lambda predictions, metadata: ({"classification": prd, **mdt} for prd, mdt in zip(predictions, metadata))
|
||||
formatter = get_formatter()
|
||||
|
||||
left = compose(classifier, lift(first))
|
||||
right = lift(second)
|
||||
|
||||
formatter = get_formatter()
|
||||
|
||||
def join_prediction_and_metadata(prd, mdt):
|
||||
return {"classification": prd, **mdt}
|
||||
|
||||
def process_batch(batch):
|
||||
classifications, metadata = juxt(left, right)(batch)
|
||||
return starmap(join_prediction_and_metadata, zip(classifications, metadata))
|
||||
|
||||
# --------
|
||||
# -- -- -- --
|
||||
# == == == ==
|
||||
@ -39,6 +47,7 @@ class Pipeline:
|
||||
# --------
|
||||
|
||||
def inspect(x):
|
||||
x = list(x)
|
||||
import IPython
|
||||
IPython.embed()
|
||||
return x
|
||||
@ -46,9 +55,7 @@ class Pipeline:
|
||||
self.pipe = rcompose(
|
||||
extractor,
|
||||
batcher,
|
||||
lift(list),
|
||||
lift(juxt(left, right)),
|
||||
starlift(merger),
|
||||
lift(process_batch),
|
||||
chain.from_iterable,
|
||||
formatter,
|
||||
)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user