Julius Unverfehrt 98dc001123 revert adhoc figure detection changes
- revert pipeline and serve logic to pre figure detection data for image
extraction changes: figure detection data as input not supported for now
2023-01-30 12:41:22 +01:00

73 lines
2.3 KiB
Python

import os
from functools import partial
from itertools import chain, tee
from typing import Iterable
from funcy import rcompose, first, compose, second, chunks, identity, rpartial
from tqdm import tqdm
from image_prediction.config import CONFIG
from image_prediction.default_objects import (
get_formatter,
get_mlflow_model_loader,
get_image_classifier,
get_extractor,
get_encoder,
)
from image_prediction.locations import MLRUNS_DIR
from image_prediction.utils.generic import lift, starlift
# Executed once, when this module is imported.
# NOTE(review): presumably meant to silence TensorFlow's native (C++) log
# output -- confirm. It likely only takes effect if it is set before TensorFlow
# is first loaded, so verify that none of the imports above pull TensorFlow in.
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
def load_pipeline(**kwargs):
    """Build a ``Pipeline`` from the configured MLflow run.

    The model loader is created from ``MLRUNS_DIR`` and the model identifier
    is read from ``CONFIG.service.mlflow_run_id``.

    Args:
        **kwargs: Passed through unchanged to the ``Pipeline`` constructor.

    Returns:
        The newly constructed ``Pipeline`` instance.
    """
    return Pipeline(
        get_mlflow_model_loader(MLRUNS_DIR),
        CONFIG.service.mlflow_run_id,
        **kwargs,
    )
def parallel(*fs):
    """Combine functions so that the i-th function is applied to the i-th argument.

    Args:
        *fs: One-argument callables.

    Returns:
        A function taking positional arguments and returning a lazy generator
        of ``fs[i](args[i])``. Pairing is done with ``zip``, so surplus
        functions or surplus arguments are silently ignored.
    """

    def apply_positionwise(*args):
        # Generator expression: nothing is called until the result is iterated.
        return (func(arg) for func, arg in zip(fs, args))

    return apply_positionwise
def star(f):
    """Adapt ``f`` to take its positional arguments as one iterable.

    Args:
        f: Callable to wrap.

    Returns:
        A one-argument function ``g`` such that ``g(x)`` equals ``f(*x)``.
    """

    def call_with_unpacked(x):
        return f(*x)

    return call_with_unpacked
class Pipeline:
    """Callable that chains extraction, classification/encoding, joining and formatting.

    All stages are composed once in ``__init__`` into a single function stored
    in ``self.pipe``. Calling the instance runs a document through that
    function and yields the results behind an optional progress bar.
    """

    def __init__(self, model_loader, model_identifier, batch_size=16, verbose=True, **kwargs):
        """Assemble the composed processing function.

        Args:
            model_loader: Forwarded to ``get_image_classifier``.
            model_identifier: Forwarded to ``get_image_classifier``.
            batch_size: Chunk size handed to ``chunks``; the classifier is
                applied to one chunk at a time.
            verbose: When false, the progress bar in ``__call__`` is disabled.
            **kwargs: Forwarded unchanged to ``get_extractor``.
        """
        self.verbose = verbose

        extractor = get_extractor(**kwargs)
        image_classifier = get_image_classifier(model_loader, model_identifier)
        formatter = get_formatter()
        encoder = get_encoder()

        # Copy the stream of pairs three times, then project each copy: the
        # first element twice (once for classification, once for encoding) and
        # the second element once (metadata, passed through untouched).
        # NOTE(review): assumes `lift` turns a per-item function into a
        # per-stream function -- confirm in image_prediction.utils.generic.
        triplicate = rpartial(tee, 3)
        project = star(parallel(lift(first), lift(first), lift(second)))
        split = rcompose(triplicate, project)

        # Cut the image stream into chunks of `batch_size`, run the classifier
        # on each chunk, and flatten the per-chunk results into one stream.
        classify = rcompose(
            partial(chunks, batch_size),
            lift(image_classifier),
            chain.from_iterable,
        )

        def merge(prd, rpr, mdt):
            # Key order matters for the resulting dict: metadata keys sit
            # between "classification" and "representation".
            return {"classification": prd, **mdt, "representation": rpr}

        # Zip the three streams back together and merge each triple.
        # NOTE(review): presumably `starlift` applies `merge` to every triple
        # with the triple unpacked -- confirm in image_prediction.utils.generic.
        join = rcompose(star(zip), starlift(merge))

        #                        />--classify--\
        # --extract-->--split--+->--encode---->+--join-->reformat
        #                        \>--identity--/
        self.pipe = rcompose(
            extractor,  # produces image-metadata pairs as a stream
            split,  # into two image streams and a metadata stream
            star(parallel(classify, encoder, identity)),  # one function per stream
            join,  # the streams by zipping
            formatter,  # reformat the items
        )

    def __call__(self, pdf: bytes, page_range: range = None):
        """Run ``self.pipe`` on a document and yield its results.

        Args:
            pdf: Passed as the positional input to ``self.pipe``.
            page_range: Forwarded to ``self.pipe`` as the ``page_range``
                keyword argument.

        Yields:
            The items produced by ``self.pipe``, wrapped in a tqdm progress
            bar that is disabled when ``self.verbose`` is false.
        """
        results = self.pipe(pdf, page_range=page_range)
        progress = tqdm(
            results,
            desc="Processing images from document",
            unit=" images",
            disable=not self.verbose,
        )
        yield from progress