refactoring

This commit is contained in:
Matthias Bisping 2022-04-21 21:04:57 +02:00
parent 6c10b55ff8
commit 9a1446cccf

View File

@ -1,8 +1,8 @@
import os import os
from functools import partial from functools import partial
from itertools import chain, starmap from itertools import chain, tee
from funcy import rcompose, juxt, first, compose, second, chunks, curry from funcy import rcompose, juxt, first, compose, second, chunks, identity
from image_prediction.config import CONFIG from image_prediction.config import CONFIG
from image_prediction.default_objects import get_formatter, get_mlflow_model_loader, get_image_classifier, get_extractor from image_prediction.default_objects import get_formatter, get_mlflow_model_loader, get_image_classifier, get_extractor
@ -21,43 +21,54 @@ def load_pipeline(**kwargs):
return pipeline return pipeline
def parallel(*fs):
return lambda *args: (f(a) for f, a in zip(fs, args))
def splat(f):
return lambda x: f(*x)
def inspect(x):
x = list(x)
import IPython
IPython.embed()
return x
class Pipeline: class Pipeline:
def __init__(self, model_loader, model_identifier, batch_size=16, **kwargs): def __init__(self, model_loader, model_identifier, batch_size=16, **kwargs):
extractor = get_extractor(**kwargs) extractor = get_extractor(**kwargs)
batcher = compose(lift(list), partial(chunks, batch_size))
classifier = get_image_classifier(model_loader, model_identifier) classifier = get_image_classifier(model_loader, model_identifier)
left = compose(classifier, lift(first))
right = lift(second)
formatter = get_formatter() formatter = get_formatter()
batcher = partial(chunks, batch_size)
classify = compose(chain.from_iterable, lift(classifier), batcher)
def join_prediction_and_metadata(prd, mdt): def join_prediction_and_metadata(prd, mdt):
return {"classification": prd, **mdt} return {"classification": prd, **mdt}
# +--classify--+
# -------- # --extract image metadata paris-->--split--| |--zip-->-join-pairs-->format-->return
# -- -- -- -- # +--identity--+
# == == == ==
# -- -- -- --
# --------
# --------
def inspect(x):
x = list(x)
import IPython
IPython.embed()
return x
self.pipe = rcompose( self.pipe = rcompose(
extractor, extractor,
batcher, tee,
lift(juxt(left, right)), splat(parallel(*map(lift, (first, second)))),
starlift(zip), splat(parallel(classify, identity)),
lift(starlift(join_prediction_and_metadata)), splat(zip),
chain.from_iterable, starlift(join_prediction_and_metadata),
formatter, formatter,
) )
def __call__(self, pdf: bytes, page_range: range = None): def __call__(self, pdf: bytes, page_range: range = None):
yield from self.pipe(pdf, page_range=page_range) r = self.pipe(pdf, page_range=page_range)
r = list(r)
return r