refactoring
This commit is contained in:
parent
9a1446cccf
commit
50b161192d
@@ -2,7 +2,7 @@ import os
|
||||
from functools import partial
|
||||
from itertools import chain, tee
|
||||
|
||||
from funcy import rcompose, juxt, first, compose, second, chunks, identity
|
||||
from funcy import rcompose, first, compose, second, chunks, identity
|
||||
|
||||
from image_prediction.config import CONFIG
|
||||
from image_prediction.default_objects import get_formatter, get_mlflow_model_loader, get_image_classifier, get_extractor
|
||||
@@ -29,46 +29,27 @@ def splat(f):
|
||||
return lambda x: f(*x)
|
||||
|
||||
|
||||
def inspect(x):
|
||||
x = list(x)
|
||||
import IPython
|
||||
|
||||
IPython.embed()
|
||||
return x
|
||||
|
||||
|
||||
class Pipeline:
|
||||
def __init__(self, model_loader, model_identifier, batch_size=16, **kwargs):
|
||||
extractor = get_extractor(**kwargs)
|
||||
extract = get_extractor(**kwargs)
|
||||
classifier = get_image_classifier(model_loader, model_identifier)
|
||||
formatter = get_formatter()
|
||||
reformat = get_formatter()
|
||||
|
||||
batcher = partial(chunks, batch_size)
|
||||
split = compose(splat(parallel(*map(lift, (first, second)))), tee)
|
||||
classify = compose(chain.from_iterable, lift(classifier), partial(chunks, batch_size))
|
||||
join = compose(starlift(lambda prd, mdt: {"classification": prd, **mdt}), splat(zip))
|
||||
|
||||
classify = compose(chain.from_iterable, lift(classifier), batcher)
|
||||
|
||||
|
||||
def join_prediction_and_metadata(prd, mdt):
|
||||
return {"classification": prd, **mdt}
|
||||
|
||||
# +--classify--+
|
||||
# --extract image metadata pairs-->--split--| |--zip-->-join-pairs-->format-->return
|
||||
# +--identity--+
|
||||
# +>--classify--v
|
||||
# --extract image metadata pairs-->--split--| |--join-->format
|
||||
# +>--identity--^
|
||||
|
||||
self.pipe = rcompose(
|
||||
extractor,
|
||||
tee,
|
||||
splat(parallel(*map(lift, (first, second)))),
|
||||
splat(parallel(classify, identity)),
|
||||
splat(zip),
|
||||
starlift(join_prediction_and_metadata),
|
||||
formatter,
|
||||
extract, # ... image-metadata-pairs as a stream
|
||||
split, # ... into an image stream and a metadata stream
|
||||
splat(parallel(classify, identity)), # ... process streams independently
|
||||
join, # ... the streams
|
||||
reformat, # ... the items
|
||||
)
|
||||
|
||||
def __call__(self, pdf: bytes, page_range: range = None):
|
||||
r = self.pipe(pdf, page_range=page_range)
|
||||
|
||||
r = list(r)
|
||||
|
||||
|
||||
return r
|
||||
yield from self.pipe(pdf, page_range=page_range)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user