refactoring
This commit is contained in:
parent
6c10b55ff8
commit
9a1446cccf
@ -1,8 +1,8 @@
|
|||||||
import os
|
import os
|
||||||
from functools import partial
|
from functools import partial
|
||||||
from itertools import chain, starmap
|
from itertools import chain, tee
|
||||||
|
|
||||||
from funcy import rcompose, juxt, first, compose, second, chunks, curry
|
from funcy import rcompose, juxt, first, compose, second, chunks, identity
|
||||||
|
|
||||||
from image_prediction.config import CONFIG
|
from image_prediction.config import CONFIG
|
||||||
from image_prediction.default_objects import get_formatter, get_mlflow_model_loader, get_image_classifier, get_extractor
|
from image_prediction.default_objects import get_formatter, get_mlflow_model_loader, get_image_classifier, get_extractor
|
||||||
@ -21,43 +21,54 @@ def load_pipeline(**kwargs):
|
|||||||
return pipeline
|
return pipeline
|
||||||
|
|
||||||
|
|
||||||
|
def parallel(*fs):
|
||||||
|
return lambda *args: (f(a) for f, a in zip(fs, args))
|
||||||
|
|
||||||
|
|
||||||
|
def splat(f):
|
||||||
|
return lambda x: f(*x)
|
||||||
|
|
||||||
|
|
||||||
|
def inspect(x):
|
||||||
|
x = list(x)
|
||||||
|
import IPython
|
||||||
|
|
||||||
|
IPython.embed()
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
class Pipeline:
|
class Pipeline:
|
||||||
def __init__(self, model_loader, model_identifier, batch_size=16, **kwargs):
|
def __init__(self, model_loader, model_identifier, batch_size=16, **kwargs):
|
||||||
extractor = get_extractor(**kwargs)
|
extractor = get_extractor(**kwargs)
|
||||||
batcher = compose(lift(list), partial(chunks, batch_size))
|
|
||||||
classifier = get_image_classifier(model_loader, model_identifier)
|
classifier = get_image_classifier(model_loader, model_identifier)
|
||||||
|
|
||||||
left = compose(classifier, lift(first))
|
|
||||||
right = lift(second)
|
|
||||||
|
|
||||||
formatter = get_formatter()
|
formatter = get_formatter()
|
||||||
|
|
||||||
|
batcher = partial(chunks, batch_size)
|
||||||
|
|
||||||
|
classify = compose(chain.from_iterable, lift(classifier), batcher)
|
||||||
|
|
||||||
|
|
||||||
def join_prediction_and_metadata(prd, mdt):
|
def join_prediction_and_metadata(prd, mdt):
|
||||||
return {"classification": prd, **mdt}
|
return {"classification": prd, **mdt}
|
||||||
|
|
||||||
|
# +--classify--+
|
||||||
# --------
|
# --extract image metadata paris-->--split--| |--zip-->-join-pairs-->format-->return
|
||||||
# -- -- -- --
|
# +--identity--+
|
||||||
# == == == ==
|
|
||||||
# -- -- -- --
|
|
||||||
# --------
|
|
||||||
# --------
|
|
||||||
|
|
||||||
def inspect(x):
|
|
||||||
x = list(x)
|
|
||||||
import IPython
|
|
||||||
IPython.embed()
|
|
||||||
return x
|
|
||||||
|
|
||||||
self.pipe = rcompose(
|
self.pipe = rcompose(
|
||||||
extractor,
|
extractor,
|
||||||
batcher,
|
tee,
|
||||||
lift(juxt(left, right)),
|
splat(parallel(*map(lift, (first, second)))),
|
||||||
starlift(zip),
|
splat(parallel(classify, identity)),
|
||||||
lift(starlift(join_prediction_and_metadata)),
|
splat(zip),
|
||||||
chain.from_iterable,
|
starlift(join_prediction_and_metadata),
|
||||||
formatter,
|
formatter,
|
||||||
)
|
)
|
||||||
|
|
||||||
def __call__(self, pdf: bytes, page_range: range = None):
|
def __call__(self, pdf: bytes, page_range: range = None):
|
||||||
yield from self.pipe(pdf, page_range=page_range)
|
r = self.pipe(pdf, page_range=page_range)
|
||||||
|
|
||||||
|
r = list(r)
|
||||||
|
|
||||||
|
|
||||||
|
return r
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user