refactoring pipeline WIP

This commit is contained in:
Matthias Bisping 2022-04-20 18:30:41 +02:00
parent 120721f5f1
commit f036ee55e6
4 changed files with 53 additions and 12 deletions

View File

@ -24,10 +24,14 @@ class Classifier:
self.__pipe = rcompose(self.__estimator_adapter, self.__label_mapper) self.__pipe = rcompose(self.__estimator_adapter, self.__label_mapper)
def predict(self, batch: Union[np.array, Tuple[Image]]) -> List[str]: def predict(self, batch: Union[np.array, Tuple[Image]]) -> List[str]:
if not isinstance(batch, tuple) and batch.shape[0] == 0: # TODO: necessary?
if not batch:
return [] return []
return list(self.__pipe(batch)) if isinstance(batch, np.ndarray) and batch.shape[0] == 0:
return []
return list(self.__pipe(batch)) # TODO: list?
def __call__(self, batch: np.array) -> List[str]: def __call__(self, batch: np.array) -> List[str]:
logger.debug("Classifier.predict") logger.debug("Classifier.predict")

View File

@ -35,17 +35,17 @@ class ParsablePDFImageExtractor(ImageExtractor):
pages = extract_pages(self.doc, page_range) if page_range else self.doc pages = extract_pages(self.doc, page_range) if page_range else self.doc
pages = self.__maybe_show_progress(pages, page_range) # pages = self.__maybe_show_progress(pages, page_range)
image_metadata_pairs = chain.from_iterable(map(self.__process_images_on_page, pages)) image_metadata_pairs = chain.from_iterable(map(self.__process_images_on_page, pages))
yield from image_metadata_pairs yield from image_metadata_pairs
def __maybe_show_progress(self, iterable, page_range): # def __maybe_show_progress(self, iterable, page_range):
return self.__progressbar(page_range)(iterable) if self.verbose else iterable # return self.__progressbar(page_range)(iterable) if self.verbose else iterable
#
def __progressbar(self, page_range): # def __progressbar(self, page_range):
return rpartial(tqdm, desc=self.progress_message, total=len(page_range) if page_range else None) # return rpartial(tqdm, desc=self.progress_message, total=len(page_range) if page_range else None)
def __process_images_on_page(self, page: fitz.fitz.Page): def __process_images_on_page(self, page: fitz.fitz.Page):
images = get_images_on_page(self.doc, page) images = get_images_on_page(self.doc, page)

View File

@ -1,10 +1,12 @@
import os import os
from itertools import chain
from funcy import rcompose from funcy import rcompose, juxt, first, compose, second, chunks, curry
from image_prediction.config import CONFIG from image_prediction.config import CONFIG
from image_prediction.default_objects import get_extractor_classifier, get_formatter, get_mlflow_model_loader from image_prediction.default_objects import get_formatter, get_mlflow_model_loader, get_image_classifier, get_extractor
from image_prediction.locations import MLRUNS_DIR from image_prediction.locations import MLRUNS_DIR
from image_prediction.utils.generic import lift, starlift
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3" os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
@ -19,8 +21,37 @@ def load_pipeline(**kwargs):
class Pipeline: class Pipeline:
def __init__(self, model_loader, model_identifier, **kwargs): def __init__(self, model_loader, model_identifier, batch_size=16, **kwargs):
self.pipe = rcompose(get_extractor_classifier(model_loader, model_identifier, **kwargs), get_formatter()) extractor = get_extractor(**kwargs)
batcher = lambda x: chunks(batch_size, x)
classifier = get_image_classifier(model_loader, model_identifier)
merger = lambda predictions, metadata: ({"classification": prd, **mdt} for prd, mdt in zip(predictions, metadata))
formatter = get_formatter()
left = compose(classifier, lift(first))
right = lift(second)
# --------
# -- -- -- --
# == == == ==
# -- -- -- --
# --------
# --------
def inspect(x):
import IPython
IPython.embed()
return x
self.pipe = rcompose(
extractor,
batcher,
lift(list),
lift(juxt(left, right)),
starlift(merger),
chain.from_iterable,
formatter,
)
def __call__(self, pdf: bytes, page_range: range = None): def __call__(self, pdf: bytes, page_range: range = None):
yield from self.pipe(pdf, page_range=page_range) yield from self.pipe(pdf, page_range=page_range)

View File

@ -1,3 +1,5 @@
from itertools import starmap
from funcy import iterate, first, curry, map from funcy import iterate, first, curry, map
@ -7,3 +9,7 @@ def until(cond, func, *args, **kwargs):
def lift(fn): def lift(fn):
return curry(map)(fn) return curry(map)(fn)
def starlift(fn):
return curry(starmap)(fn)