refactoring

This commit is contained in:
Matthias Bisping 2022-04-21 21:04:57 +02:00
parent 6c10b55ff8
commit 9a1446cccf

View File

@ -1,8 +1,8 @@
import os
from functools import partial
from itertools import chain, starmap
from itertools import chain, tee
from funcy import rcompose, juxt, first, compose, second, chunks, curry
from funcy import rcompose, juxt, first, compose, second, chunks, identity
from image_prediction.config import CONFIG
from image_prediction.default_objects import get_formatter, get_mlflow_model_loader, get_image_classifier, get_extractor
@ -21,43 +21,54 @@ def load_pipeline(**kwargs):
return pipeline
def parallel(*fs):
return lambda *args: (f(a) for f, a in zip(fs, args))
def splat(f):
return lambda x: f(*x)
def inspect(x):
x = list(x)
import IPython
IPython.embed()
return x
class Pipeline:
def __init__(self, model_loader, model_identifier, batch_size=16, **kwargs):
extractor = get_extractor(**kwargs)
batcher = compose(lift(list), partial(chunks, batch_size))
classifier = get_image_classifier(model_loader, model_identifier)
left = compose(classifier, lift(first))
right = lift(second)
formatter = get_formatter()
batcher = partial(chunks, batch_size)
classify = compose(chain.from_iterable, lift(classifier), batcher)
def join_prediction_and_metadata(prd, mdt):
return {"classification": prd, **mdt}
# --------
# -- -- -- --
# == == == ==
# -- -- -- --
# --------
# --------
def inspect(x):
x = list(x)
import IPython
IPython.embed()
return x
# +--classify--+
# --extract image metadata paris-->--split--| |--zip-->-join-pairs-->format-->return
# +--identity--+
self.pipe = rcompose(
extractor,
batcher,
lift(juxt(left, right)),
starlift(zip),
lift(starlift(join_prediction_and_metadata)),
chain.from_iterable,
tee,
splat(parallel(*map(lift, (first, second)))),
splat(parallel(classify, identity)),
splat(zip),
starlift(join_prediction_and_metadata),
formatter,
)
def __call__(self, pdf: bytes, page_range: range = None):
yield from self.pipe(pdf, page_range=page_range)
r = self.pipe(pdf, page_range=page_range)
r = list(r)
return r