refactoring pipeline WIP

This commit is contained in:
Matthias Bisping 2022-04-20 18:30:41 +02:00
parent 120721f5f1
commit f036ee55e6
4 changed files with 53 additions and 12 deletions

View File

@ -24,10 +24,14 @@ class Classifier:
self.__pipe = rcompose(self.__estimator_adapter, self.__label_mapper)
def predict(self, batch: Union[np.ndarray, Tuple["Image", ...]]) -> List[str]:
    """Run the estimator/label-mapper pipe over one batch and return its labels.

    Args:
        batch: Either a numpy array whose first axis indexes the batch items,
            or a tuple of images.

    Returns:
        The pipe's output for ``batch`` materialised as a list; an empty list
        when the batch contains no items.
    """
    if isinstance(batch, np.ndarray):
        # An ndarray must not be truth-tested: ``not batch`` raises ValueError
        # for arrays with more than one element and misreports a one-element
        # array holding 0 as empty. Inspect the length of the first axis instead.
        is_empty = batch.shape[0] == 0
    else:
        is_empty = not batch

    # TODO: confirm whether the empty-batch short-circuit is still necessary.
    if is_empty:
        return []

    # Materialise so callers get the declared List[str] rather than a lazy iterator.
    return list(self.__pipe(batch))
def __call__(self, batch: np.array) -> List[str]:
logger.debug("Classifier.predict")

View File

@ -35,17 +35,17 @@ class ParsablePDFImageExtractor(ImageExtractor):
pages = extract_pages(self.doc, page_range) if page_range else self.doc
pages = self.__maybe_show_progress(pages, page_range)
# pages = self.__maybe_show_progress(pages, page_range)
image_metadata_pairs = chain.from_iterable(map(self.__process_images_on_page, pages))
yield from image_metadata_pairs
def __maybe_show_progress(self, iterable, page_range):
    """Wrap ``iterable`` with the progress bar when ``self.verbose`` is truthy.

    When ``self.verbose`` is falsy the iterable is handed back unchanged.
    """
    if not self.verbose:
        return iterable
    wrap_with_progress = self.__progressbar(page_range)
    return wrap_with_progress(iterable)
def __progressbar(self, page_range):
    """Build a ``tqdm`` wrapper pre-configured with description and total.

    ``total`` is ``len(page_range)`` when ``page_range`` is truthy and ``None``
    otherwise; the description is taken from ``self.progress_message``.
    """
    description = self.progress_message
    if page_range:
        page_count = len(page_range)
    else:
        page_count = None
    return rpartial(tqdm, desc=description, total=page_count)
# def __maybe_show_progress(self, iterable, page_range):
# return self.__progressbar(page_range)(iterable) if self.verbose else iterable
#
# def __progressbar(self, page_range):
# return rpartial(tqdm, desc=self.progress_message, total=len(page_range) if page_range else None)
def __process_images_on_page(self, page: fitz.fitz.Page):
images = get_images_on_page(self.doc, page)

View File

@ -1,10 +1,12 @@
import os
from itertools import chain
from funcy import rcompose
from funcy import rcompose, juxt, first, compose, second, chunks, curry
from image_prediction.config import CONFIG
from image_prediction.default_objects import get_extractor_classifier, get_formatter, get_mlflow_model_loader
from image_prediction.default_objects import get_formatter, get_mlflow_model_loader, get_image_classifier, get_extractor
from image_prediction.locations import MLRUNS_DIR
from image_prediction.utils.generic import lift, starlift
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
@ -19,8 +21,37 @@ def load_pipeline(**kwargs):
class Pipeline:
    """Callable that chains extraction, batched classification and formatting.

    The stages are composed left-to-right with ``rcompose`` and stored on
    ``self.pipe``; calling the instance feeds a PDF through that composition.
    """

    def __init__(self, model_loader, model_identifier, batch_size=16, **kwargs):
        """Assemble the processing stages.

        Args:
            model_loader: Forwarded to ``get_image_classifier``.
            model_identifier: Forwarded to ``get_image_classifier``.
            batch_size: Number of extractor outputs grouped into one batch.
            **kwargs: Forwarded to ``get_extractor``.
        """
        extractor = get_extractor(**kwargs)
        classifier = get_image_classifier(model_loader, model_identifier)
        formatter = get_formatter()

        def batcher(items):
            # Group the extractor's output into chunks of ``batch_size``.
            return chunks(batch_size, items)

        def merger(predictions, metadata):
            # Pair each prediction with its metadata mapping and merge both
            # into a single dict per item.
            return ({"classification": prd, **mdt} for prd, mdt in zip(predictions, metadata))

        # Each batch item is treated as a pair: the first element goes to the
        # classifier, the second is carried along as metadata.
        # (presumably (image, metadata) pairs -- confirm against the extractor)
        left = compose(classifier, lift(first))
        right = lift(second)

        # Data flow:
        #   extractor -> batches -> [left: classify firsts | right: keep seconds]
        #             -> merge per batch -> flatten batches -> formatter
        self.pipe = rcompose(
            extractor,
            batcher,
            # Both branches below iterate the same batch, so every batch must
            # be a re-iterable list rather than a one-shot iterator.
            lift(list),
            lift(juxt(left, right)),
            starlift(merger),
            chain.from_iterable,
            formatter,
        )

    def __call__(self, pdf: bytes, page_range: range = None):
        """Yield the pipeline's results for ``pdf``.

        Args:
            pdf: Raw PDF content handed to the first pipeline stage.
            page_range: Optional page range, passed to the first stage as the
                ``page_range`` keyword argument.
        """
        yield from self.pipe(pdf, page_range=page_range)

View File

@ -1,3 +1,5 @@
from itertools import starmap
from funcy import iterate, first, curry, map
@ -7,3 +9,7 @@ def until(cond, func, *args, **kwargs):
def lift(fn):
    """Return the curried ``map`` with ``fn`` already supplied.

    The result is intended to be called with an iterable so that ``fn`` is
    mapped over it (relies on ``curry`` treating ``map`` as a two-argument
    function -- confirm against the funcy version in use).
    """
    curried_map = curry(map)
    return curried_map(fn)
def starlift(fn):
    """Return the curried ``starmap`` with ``fn`` already supplied.

    The result is intended to be called with an iterable of argument tuples,
    each of which is unpacked into a call of ``fn`` (relies on ``curry``
    treating ``starmap`` as a two-argument function -- confirm against the
    funcy version in use).
    """
    curried_starmap = curry(starmap)
    return curried_starmap(fn)