refactoring

This commit is contained in:
Matthias Bisping 2022-04-21 21:20:18 +02:00
parent 9a1446cccf
commit 50b161192d

View File

@ -2,7 +2,7 @@ import os
from functools import partial
from itertools import chain, tee
from funcy import rcompose, juxt, first, compose, second, chunks, identity
from funcy import rcompose, first, compose, second, chunks, identity
from image_prediction.config import CONFIG
from image_prediction.default_objects import get_formatter, get_mlflow_model_loader, get_image_classifier, get_extractor
@ -29,46 +29,27 @@ def splat(f):
return lambda x: f(*x)
def inspect(x):
    """Debugging tap: materialize *x*, open an interactive shell, pass it on.

    Intended to be dropped between pipeline stages while debugging. The
    input is consumed into a list first so that it can be examined from the
    shell without exhausting the stream seen by the next stage.

    Args:
        x: Any iterable.

    Returns:
        list: The items of *x*.
    """
    x = list(x)

    # Imported locally so that IPython is only required when this debugging
    # helper is actually called.
    import IPython

    IPython.embed()  # blocks until the interactive session is closed
    return x
class Pipeline:
    """Image-classification pipeline assembled from composable stream stages.

    The stages are wired together once in ``__init__`` and stored in
    ``self.pipe``; calling the instance pushes a document through them.
    """

    def __init__(self, model_loader, model_identifier, batch_size=16, **kwargs):
        """Build the processing pipeline.

        Args:
            model_loader: Passed to ``get_image_classifier`` together with
                ``model_identifier``.
            model_identifier: Passed to ``get_image_classifier``.
            batch_size: Chunk size used to batch the image stream before it
                is handed to the classifier.
            **kwargs: Forwarded unchanged to ``get_extractor``.
        """
        extract = get_extractor(**kwargs)
        classifier = get_image_classifier(model_loader, model_identifier)
        reformat = get_formatter()

        # NOTE(review): `lift`, `starlift` and `parallel` are defined outside
        # this block; presumably they map a function over a stream, star-map
        # it, and apply one function per element of a tuple -- confirm.

        # Duplicate the pair stream with `tee`, then keep the first element of
        # each pair in one copy and the second element in the other.
        split = compose(splat(parallel(*map(lift, (first, second)))), tee)

        # Batch the images, classify batch-wise, then flatten the per-batch
        # results back into a single stream of predictions.
        classify = compose(chain.from_iterable, lift(classifier), partial(chunks, batch_size))

        # Zip predictions with their metadata and merge each pair into a dict.
        join = compose(starlift(lambda prd, mdt: {"classification": prd, **mdt}), splat(zip))

        #                                           +>--classify--v
        # --extract image metadata pairs-->--split--|             |--join-->format
        #                                           +>--identity--^

        self.pipe = rcompose(
            extract,  # ... image-metadata-pairs as a stream
            split,  # ... into an image stream and a metadata stream
            splat(parallel(classify, identity)),  # ... process streams independently
            join,  # ... the streams
            reformat,  # ... the items
        )

    def __call__(self, pdf: bytes, page_range: range = None):
        """Run *pdf* through the pipeline and yield the results one by one.

        This is a generator: no work is done until the result is iterated.

        Args:
            pdf: The document to process.
            page_range: Forwarded to the pipeline as the ``page_range``
                keyword argument.

        Yields:
            The items produced by ``self.pipe``.
        """
        yield from self.pipe(pdf, page_range=page_range)