Matthias Bisping 72e785e3e3 refactoring
2022-04-21 19:43:39 +02:00

65 lines
1.9 KiB
Python

import os
from functools import partial
from itertools import chain, starmap
from funcy import rcompose, juxt, first, compose, second, chunks, curry
from image_prediction.config import CONFIG
from image_prediction.default_objects import get_formatter, get_mlflow_model_loader, get_image_classifier, get_extractor
from image_prediction.locations import MLRUNS_DIR
from image_prediction.utils.generic import lift, starlift
# Set the TF_CPP_MIN_LOG_LEVEL environment variable for this process at import time.
# NOTE(review): presumably "3" silences TensorFlow's native (C++) log output and only takes
# effect if it is set before TensorFlow is first loaded — confirm that none of the imports
# above already pull TensorFlow in.
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
def load_pipeline(**kwargs):
    """Build a ``Pipeline`` for the model configured for this service.

    The model loader is created for ``MLRUNS_DIR`` and the model is
    identified by ``CONFIG.service.run_id``.

    Args:
        **kwargs: Forwarded unchanged to ``Pipeline``, alongside the fixed
            ``progress_message="Processing document"`` keyword.

    Returns:
        The newly constructed ``Pipeline`` instance.
    """
    return Pipeline(
        get_mlflow_model_loader(MLRUNS_DIR),
        CONFIG.service.run_id,
        progress_message="Processing document",
        **kwargs,
    )
class Pipeline:
    """Callable that runs extraction, batched classification and formatting on a PDF.

    The stages are wired together once in ``__init__`` and stored as a single
    composed function in ``self.pipe``; calling the instance feeds a document
    through that function.
    """

    def __init__(self, model_loader, model_identifier, batch_size=16, **kwargs):
        """Assemble the processing stages into ``self.pipe``.

        Args:
            model_loader: Passed to ``get_image_classifier`` together with
                ``model_identifier`` to obtain the classifier.
            model_identifier: Identifier of the model to classify with.
            batch_size: Number of extracted items grouped into one classifier
                call (the last batch may be smaller).
            **kwargs: Forwarded unchanged to ``get_extractor``.
        """
        extractor = get_extractor(**kwargs)
        # ``compose`` applies right-to-left: cut the extracted stream into chunks of
        # ``batch_size`` first, then turn every chunk into a list so that it can be
        # traversed twice in ``process_batch``.
        # NOTE(review): ``lift`` (image_prediction.utils.generic) is assumed to lift a
        # function so that it is applied to every element of an iterable — confirm.
        batcher = compose(lift(list), partial(chunks, batch_size))
        classifier = get_image_classifier(model_loader, model_identifier)
        # Every extracted item is treated as a pair: its first element goes to the
        # classifier, its second element is a mapping that is merged into the result.
        classify_firsts = compose(classifier, lift(first))
        collect_seconds = lift(second)
        formatter = get_formatter()

        def join_prediction_and_metadata(prd, mdt):
            """Merge one prediction with its metadata mapping into a single dict."""
            return {"classification": prd, **mdt}

        def process_batch(batch):
            """Classify one batch and pair each prediction with its metadata."""
            classifications, metadata = juxt(classify_firsts, collect_seconds)(batch)
            return starmap(join_prediction_and_metadata, zip(classifications, metadata))

        # ``rcompose`` applies left-to-right:
        #   extract items -> group into batches -> process every batch
        #   -> flatten the per-batch results into one stream -> format
        self.pipe = rcompose(
            extractor,
            batcher,
            lift(process_batch),
            chain.from_iterable,
            formatter,
        )

    def __call__(self, pdf: bytes, page_range: range = None):
        """Yield the formatted results for ``pdf``.

        Args:
            pdf: The document as raw bytes; handed to the extractor.
            page_range: Passed to the extractor as the ``page_range`` keyword;
                defaults to ``None``.

        Yields:
            Whatever the formatter stage produces, one item at a time.
        """
        yield from self.pipe(pdf, page_range=page_range)