Julius Unverfehrt 265c61df1a RED-5107 add handling for 'broken' images: broken image parts are
replaced by blank images in the stitching process and completely broken
images are also replaced by blank images which are passed through and
are classified as 'other' with all_pased == False. This should be
changed in the future by introducing a new key to the response,
indicating that the image is not valid.
2022-08-30 12:30:23 +02:00

72 lines
2.3 KiB
Python

import os
from functools import partial
from itertools import chain, tee
from funcy import rcompose, first, compose, second, chunks, identity, rpartial
from tqdm import tqdm
from image_prediction.config import CONFIG
from image_prediction.default_objects import (
get_formatter,
get_mlflow_model_loader,
get_image_classifier,
get_extractor,
get_encoder,
)
from image_prediction.locations import MLRUNS_DIR
from image_prediction.utils.generic import lift, starlift
# Quieten TensorFlow's native (C++) logging; "3" is presumably the most
# restrictive level (INFO, WARNING and ERROR suppressed) - confirm against the
# TensorFlow version in use.
# NOTE(review): this runs after the image_prediction imports above; if any of
# them already imports TensorFlow, the setting may come too late - confirm.
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
def load_pipeline(**kwargs):
    """Build a ``Pipeline`` for the MLflow run configured for the service.

    The model loader is created for ``MLRUNS_DIR`` and the run id is read
    from ``CONFIG.service.mlflow_run_id``.

    Args:
        **kwargs: Passed through unchanged to the ``Pipeline`` constructor.

    Returns:
        The newly constructed ``Pipeline``.
    """
    return Pipeline(
        get_mlflow_model_loader(MLRUNS_DIR),
        CONFIG.service.mlflow_run_id,
        **kwargs,
    )
def parallel(*fs):
    """Pair callables with arguments by position and apply each to its own argument.

    Args:
        *fs: One-argument callables.

    Returns:
        A callable accepting positional arguments. Calling it returns a lazy
        generator yielding ``fs[0](args[0])``, ``fs[1](args[1])``, and so on.
        Pairing is done with ``zip``, so surplus callables or surplus
        arguments are silently ignored.
    """

    def apply_each(*args):
        # Nothing is called until the returned generator is consumed.
        return (func(arg) for func, arg in zip(fs, args))

    return apply_each
def star(f):
    """Adapt ``f`` so it takes a single iterable that is unpacked into its arguments.

    Args:
        f: Callable to adapt.

    Returns:
        A one-argument callable; ``star(f)(x)`` evaluates ``f(*x)``.
    """

    def call_unpacked(x):
        return f(*x)

    return call_unpacked
class Pipeline:
    """Callable that chains extraction, classification, encoding and formatting.

    The chain is assembled once in ``__init__`` and stored in ``self.pipe``;
    calling the instance runs it on a document behind a progress bar.
    """

    def __init__(self, model_loader, model_identifier, batch_size=16, verbose=True, **kwargs):
        """Assemble the processing chain.

        Args:
            model_loader: Forwarded to ``get_image_classifier``.
            model_identifier: Forwarded to ``get_image_classifier``.
            batch_size: Chunk size used when feeding the classifier.
            verbose: When false, the progress bar in ``__call__`` is disabled.
            **kwargs: Forwarded to ``get_extractor``.
        """
        self.verbose = verbose

        # Building blocks supplied by the project's default factories.
        extract = get_extractor(**kwargs)
        classifier = get_image_classifier(model_loader, model_identifier)
        reformat = get_formatter()
        represent = get_encoder()

        def merge(prediction, representation, metadata):
            # Later keys win: metadata may override "classification", while
            # "representation" overrides a same-named metadata entry.
            return {"classification": prediction, **metadata, "representation": representation}

        # NOTE(review): `lift` / `starlift` come from image_prediction.utils.generic;
        # presumably they turn a per-item function into a per-stream one - confirm.

        # Make three copies of the extracted stream, then project them with
        # `first`, `first` and `second` respectively.
        fan_out = rcompose(
            rpartial(tee, 3),
            star(parallel(lift(first), lift(first), lift(second))),
        )

        # Cut the stream into chunks of `batch_size`, classify them, and
        # flatten the results back into one stream.
        classify = rcompose(
            partial(chunks, batch_size),
            lift(classifier),
            chain.from_iterable,
        )

        # Turns N functions into one function applying them to N streams.
        per_stream = compose(star, parallel)

        # Zip the three streams together and merge each triple into one dict.
        join = rcompose(star(zip), starlift(merge))

        #                        />--classify---\
        # --extract-->--split---+->--represent-->+--join-->--reformat
        #                        \>--identity---/
        self.pipe = rcompose(
            extract,  # yields the stream that is fanned out below
            fan_out,  # one stream becomes three
            per_stream(classify, represent, identity),  # one function per stream
            join,  # three streams become one stream of dicts
            reformat,  # final formatting step
        )

    def __call__(self, pdf: bytes, page_range: range = None):
        """Lazily yield the pipeline's results for ``pdf`` behind a progress bar.

        Args:
            pdf: Document handed to the pipeline as its positional argument.
            page_range: Handed to the pipeline as the ``page_range`` keyword.

        Yields:
            Every item produced by ``self.pipe``.
        """
        results = self.pipe(pdf, page_range=page_range)
        progress = tqdm(
            results,
            desc="Processing images from document",
            unit=" images",
            disable=not self.verbose,
        )
        yield from progress