Matthias Bisping 75748a1d82 refactoring
2022-04-25 11:19:26 +02:00

65 lines
2.2 KiB
Python

import os
from functools import partial
from itertools import chain, tee
from typing import Optional

from funcy import rcompose, first, compose, second, chunks, identity
from tqdm import tqdm

from image_prediction.config import CONFIG
from image_prediction.default_objects import get_formatter, get_mlflow_model_loader, get_image_classifier, get_extractor
from image_prediction.locations import MLRUNS_DIR
from image_prediction.utils.generic import lift, starlift
# Set TensorFlow's native (C++) minimum log level to "3", presumably to keep
# its console output quiet while the pipeline runs.
# NOTE(review): this only helps if TensorFlow has not read the variable yet;
# the image_prediction imports above may already load TensorFlow — confirm.
os.environ.update({"TF_CPP_MIN_LOG_LEVEL": "3"})
def load_pipeline(**kwargs):
    """Construct a ``Pipeline`` for the configured model run.

    The model loader comes from ``get_mlflow_model_loader(MLRUNS_DIR)`` and the
    model identifier from ``CONFIG.service.run_id``. Every keyword argument is
    forwarded unchanged to ``Pipeline``.
    """
    # Arguments are evaluated left to right: loader first, then the run id.
    return Pipeline(
        get_mlflow_model_loader(MLRUNS_DIR),
        CONFIG.service.run_id,
        **kwargs,
    )
def parallel(*fs):
    """Return a callable that applies the i-th function to its i-th positional argument.

    The callable returns a lazy generator of the results. Functions and
    arguments are paired with ``zip``, so any surplus on either side is
    silently dropped.
    """
    def apply_each(*args):
        for func, arg in zip(fs, args):
            yield func(arg)

    return apply_each
def star(f):
    """Adapt *f* to take one iterable whose items become its positional arguments.

    ``star(f)(x)`` is equivalent to ``f(*x)``.
    """
    def unpacked(x):
        return f(*x)

    return unpacked
class Pipeline:
    """Extract images from a PDF, classify them in batches and format the results.

    Calling an instance with PDF bytes yields one formatted item at a time,
    optionally behind a tqdm progress bar.
    """

    def __init__(self, model_loader, model_identifier, batch_size=16, verbose=True, **kwargs):
        """Assemble the lazy processing pipeline.

        Args:
            model_loader: Passed together with ``model_identifier`` to
                ``get_image_classifier``.
            model_identifier: Model identifier handed to ``get_image_classifier``.
            batch_size: Number of images passed to the classifier per call.
            verbose: If true, iteration is wrapped in a visible tqdm progress bar.
            **kwargs: Forwarded to ``get_extractor``.

        Raises:
            ValueError: If ``batch_size`` is smaller than 1.
        """
        # Fail fast, before the (potentially expensive) extractor / classifier
        # are built: chunking with a non-positive size cannot form batches.
        if batch_size < 1:
            raise ValueError(f"batch_size must be >= 1, got {batch_size!r}")

        self.verbose = verbose

        extract = get_extractor(**kwargs)
        classifier = get_image_classifier(model_loader, model_identifier)
        reformat = get_formatter()

        # Duplicate the (image, metadata) stream with ``tee`` and project each
        # copy: ``first`` keeps the image, ``second`` keeps the metadata.
        # NOTE(review): assumes ``lift(f)`` lazily maps ``f`` over a stream.
        split = compose(star(parallel(lift(first), lift(second))), tee)
        # Cut the image stream into batches of ``batch_size``, classify each
        # batch, then flatten the per-batch results into one prediction stream.
        classify = compose(chain.from_iterable, lift(classifier), partial(chunks, batch_size))
        # pairwise_apply(f, g) turns a pair of streams (a, b) into f(a), g(b).
        pairwise_apply = compose(star, parallel)
        # Zip predictions with metadata and merge each pair into one dict.
        join = compose(starlift(lambda prd, mdt: {"classification": prd, **mdt}), star(zip))

        # Data flow:
        #
        #                               +-->--classify--+
        #   --extract-->--split-->------|               |-->--join-->--reformat-->
        #                               +-->--identity--+
        #
        self.pipe = rcompose(
            extract,  # ... image-metadata pairs as a stream
            split,  # ... into an image stream and a metadata stream
            pairwise_apply(classify, identity),  # classify images, pass metadata through
            join,  # ... the two streams back together by zipping
            reformat,  # ... the joined records
        )

    def __call__(self, pdf: bytes, page_range: Optional[range] = None):
        """Run the pipeline on *pdf* and yield the formatted results one by one.

        Args:
            pdf: The PDF document as bytes.
            page_range: Forwarded as ``page_range=`` to the extractor. ``None``
                is passed through unchanged (presumably "all pages" — confirm
                against ``get_extractor``).

        Yields:
            Each item produced by the formatter. Progress is shown with tqdm
            unless ``self.verbose`` is false.
        """
        yield from tqdm(
            self.pipe(pdf, page_range=page_range),
            desc="Processing images from document",
            unit=" images",
            disable=not self.verbose,
        )