64 lines
1.8 KiB
Python
64 lines
1.8 KiB
Python
import os
|
|
from functools import partial
|
|
from itertools import chain, starmap
|
|
|
|
from funcy import rcompose, juxt, first, compose, second, chunks, curry
|
|
|
|
from image_prediction.config import CONFIG
|
|
from image_prediction.default_objects import get_formatter, get_mlflow_model_loader, get_image_classifier, get_extractor
|
|
from image_prediction.locations import MLRUNS_DIR
|
|
from image_prediction.utils.generic import lift, starlift
|
|
|
|
# TensorFlow reads TF_CPP_MIN_LOG_LEVEL to filter its native (C++) log output;
# the value "3" filters INFO, WARNING and ERROR messages.
# NOTE(review): the variable only takes effect if it is set before TensorFlow
# is first imported; the image_prediction imports above may already import it
# transitively — confirm, and move this above those imports if so.
os.environ.update(TF_CPP_MIN_LOG_LEVEL="3")
|
|
|
|
|
|
def load_pipeline(**kwargs):
    """Build a :class:`Pipeline` for the model run configured in ``CONFIG``.

    The model loader is created by ``get_mlflow_model_loader(MLRUNS_DIR)`` and
    the model is identified by ``CONFIG.service.run_id``.

    Args:
        **kwargs: Forwarded to :class:`Pipeline` (e.g. ``batch_size``); any
            keyword it does not consume itself is passed on to
            ``get_extractor``. ``progress_message`` defaults to
            ``"Processing document"`` and may be overridden by the caller.

    Returns:
        A configured :class:`Pipeline` instance.
    """
    model_loader = get_mlflow_model_loader(MLRUNS_DIR)
    model_identifier = CONFIG.service.run_id

    # Merge the default into kwargs instead of passing it next to **kwargs:
    # the latter raises "got multiple values for keyword argument" as soon as a
    # caller supplies its own progress_message. The caller's value wins here.
    options = {"progress_message": "Processing document", **kwargs}

    return Pipeline(model_loader, model_identifier, **options)
|
|
|
|
|
|
class Pipeline:
    """Document-processing pipeline: extract, batch, classify, merge, format.

    The stages are composed once in ``__init__`` into ``self.pipe``. Calling
    the instance feeds a PDF through them:

    1. ``extractor`` turns the PDF into a stream of items (presumably
       ``(image, metadata)`` pairs — TODO confirm against ``get_extractor``).
    2. Items are grouped into lists of at most ``batch_size`` items.
    3. For each batch, the classifier is applied to the first element of every
       item; the second element of every item is kept alongside.
    4. Predictions and metadata are paired up and merged into dicts of the
       form ``{"classification": prediction, **metadata}``.
    5. The per-batch results are flattened and handed to ``formatter``.
    """

    def __init__(self, model_loader, model_identifier, batch_size=16, **kwargs):
        """Compose the pipeline stages.

        Args:
            model_loader: Passed, together with ``model_identifier``, to
                ``get_image_classifier`` to obtain the classifier.
            model_identifier: Identifies the model to load (see above).
            batch_size: Maximum number of extracted items per classifier call.
            **kwargs: Forwarded unchanged to ``get_extractor``.

        Raises:
            ValueError: If ``batch_size`` is smaller than 1.
        """
        # Reject early: chunks() cannot make progress with a non-positive size,
        # and because the pipeline is lazy the problem would otherwise only
        # surface (if at all) once the result is iterated.
        if batch_size < 1:
            raise ValueError(f"batch_size must be a positive integer, got {batch_size!r}")

        extractor = get_extractor(**kwargs)
        classifier = get_image_classifier(model_loader, model_identifier)
        formatter = get_formatter()

        # Split the extracted stream into lists of at most batch_size items.
        # NOTE(review): lift(f) presumably maps f over an iterable and
        # starlift(f) maps f over an iterable of argument tuples — confirm in
        # image_prediction.utils.generic.
        batcher = compose(lift(list), partial(chunks, batch_size))

        # Per batch: run the classifier on the first element of every item ...
        classify_batch = compose(classifier, lift(first))
        # ... and keep the second element (metadata) of every item.
        metadata_of_batch = lift(second)

        def join_prediction_and_metadata(prd, mdt):
            """Merge one prediction into a copy of its metadata mapping."""
            return {"classification": prd, **mdt}

        self.pipe = rcompose(
            extractor,
            batcher,
            # batch -> (predictions, metadata)
            lift(juxt(classify_batch, metadata_of_batch)),
            # (predictions, metadata) -> pairs (prediction, metadata)
            starlift(zip),
            # pairs -> merged dicts, still grouped per batch
            lift(starlift(join_prediction_and_metadata)),
            # flatten the per-batch groups into one stream
            chain.from_iterable,
            formatter,
        )

    def __call__(self, pdf: bytes, page_range: range = None):
        """Run the pipeline on ``pdf`` and yield the formatter's output items.

        This is a generator: no extraction or classification happens until
        the result is iterated.

        Args:
            pdf: The PDF document as bytes.
            page_range: Forwarded to the extractor as ``page_range``; ``None``
                is passed through unchanged (presumably "all pages" — confirm
                against ``get_extractor``).

        Yields:
            Items produced by the formatter stage.
        """
        yield from self.pipe(pdf, page_range=page_range)
|