Also sets image stitching tolerance default to one (pixel) and adds informative log of which settings are loaded when initializing the image classification pipeline.
75 lines
2.4 KiB
Python
75 lines
2.4 KiB
Python
import os
|
|
from functools import lru_cache, partial
|
|
from itertools import chain, tee
|
|
|
|
from funcy import rcompose, first, compose, second, chunks, identity, rpartial
|
|
from kn_utils.logging import logger
|
|
from tqdm import tqdm
|
|
|
|
from image_prediction.config import CONFIG
|
|
from image_prediction.default_objects import (
|
|
get_formatter,
|
|
get_mlflow_model_loader,
|
|
get_image_classifier,
|
|
get_extractor,
|
|
get_encoder,
|
|
)
|
|
from image_prediction.locations import MLRUNS_DIR
|
|
from image_prediction.utils.generic import lift, starlift
|
|
|
|
# Import-time side effect: sets the TF_CPP_MIN_LOG_LEVEL environment variable for this process.
# NOTE(review): presumably this silences TensorFlow's native (C++) log output, "3" being the most
# restrictive filter level; it likely has to be set before TensorFlow is first loaded — confirm.
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
|
|
|
|
|
|
@lru_cache(maxsize=None)
def load_pipeline(**kwargs):
    """Build the image classification pipeline, memoized per set of keyword arguments.

    The function is wrapped in an unbounded ``lru_cache``: a repeated call with the same keyword
    arguments returns the previously built ``Pipeline`` and neither logs nor loads anything again.
    Because the arguments form the cache key, every keyword value must be hashable.

    Args:
        **kwargs: Settings forwarded unchanged to the ``Pipeline`` constructor.

    Returns:
        A ``Pipeline`` built with the loader from ``get_mlflow_model_loader(MLRUNS_DIR)`` and the
        model identifier read from ``CONFIG.service.mlflow_run_id``.
    """
    logger.info(f"Loading pipeline with kwargs: {kwargs}")

    # Arguments are evaluated left to right: the loader is created before the run id is read.
    return Pipeline(
        get_mlflow_model_loader(MLRUNS_DIR),
        CONFIG.service.mlflow_run_id,
        **kwargs,
    )
|
|
|
|
|
|
def parallel(*fs):
    """Combine several one-argument functions into one that applies them position-wise.

    Args:
        *fs: Functions; the i-th function is applied to the i-th argument of the returned callable.

    Returns:
        A callable taking positional arguments and returning a generator of ``f(a)`` results.
        Pairing uses ``zip``, so surplus functions or surplus arguments are silently ignored.
        The individual calls only happen when the generator is consumed.
    """

    def apply_positionwise(*args):
        return (func(arg) for func, arg in zip(fs, args))

    return apply_positionwise
|
|
|
|
|
|
def star(f):
    """Adapt ``f`` so that it takes a single iterable and receives its items as positional arguments.

    Args:
        f: The callable to wrap.

    Returns:
        A one-argument callable ``g`` with ``g(x) == f(*x)``.
    """

    def call_unpacked(x):
        return f(*x)

    return call_unpacked
|
|
|
|
|
|
class Pipeline:
    """Callable that streams formatted classification results for the images of a PDF."""

    def __init__(self, model_loader, model_identifier, batch_size=16, verbose=False, **kwargs):
        """Assemble the processing chain from the default components.

        Args:
            model_loader: Handed to ``get_image_classifier``.
            model_identifier: Handed to ``get_image_classifier`` together with the loader.
            batch_size: Chunk size passed to ``funcy.chunks`` in front of the classifier.
            verbose: When true, ``__call__`` shows a tqdm progress bar.
            **kwargs: Forwarded to ``get_extractor``.
        """
        self.verbose = verbose

        extract_pairs = get_extractor(**kwargs)
        classifier = get_image_classifier(model_loader, model_identifier)
        format_items = get_formatter()
        encode = get_encoder()

        def merge(prediction, representation, metadata):
            # Key order matters for the emitted dict: metadata sits between the two computed fields.
            return {"classification": prediction, **metadata, "representation": representation}

        # Triplicate the pair stream, then project the copies to (first, first, second) of each
        # pair, i.e. two image streams and one metadata stream.
        # NOTE(review): assumes `lift` turns a function into one applied to each stream item — confirm.
        fan_out = rcompose(
            rpartial(tee, 3),
            star(parallel(lift(first), lift(first), lift(second))),
        )

        # Group images into chunks, run the classifier on the chunks, flatten the results again.
        classify_stream = rcompose(
            partial(chunks, batch_size),
            lift(classifier),
            chain.from_iterable,
        )

        # One function per stream, matched by position.
        process_streams = star(parallel(classify_stream, encode, identity))

        # Zip the three result streams back together and merge each triple into one dict.
        zip_and_merge = rcompose(star(zip), starlift(merge))

        #                        />--classify---\
        # --extract-->--fan out--+->--encode---->+--zip/merge-->format
        #                        \>--identity---/

        self.pipe = rcompose(
            extract_pairs,  # image-metadata pairs as a stream
            fan_out,  # two image streams and a metadata stream
            process_streams,  # classify / encode / pass through
            zip_and_merge,  # one dict per image
            format_items,  # final formatting of the items
        )

    def __call__(self, pdf: bytes, page_range: range = None):
        """Lazily yield the pipeline results for ``pdf``.

        Args:
            pdf: Document bytes, passed positionally to the pipeline.
            page_range: Passed through to the pipeline as the ``page_range`` keyword.

        Yields:
            The items produced by the composed pipeline, wrapped in a tqdm progress bar that is
            disabled unless ``verbose`` was set.
        """
        results = self.pipe(pdf, page_range=page_range)
        yield from tqdm(
            results,
            desc="Processing images from document",
            unit=" images",
            disable=not self.verbose,
        )
|