diff --git a/image_prediction/classifier/classifier.py b/image_prediction/classifier/classifier.py index b537aa1..4500572 100644 --- a/image_prediction/classifier/classifier.py +++ b/image_prediction/classifier/classifier.py @@ -24,10 +24,14 @@ class Classifier: self.__pipe = rcompose(self.__estimator_adapter, self.__label_mapper) def predict(self, batch: Union[np.array, Tuple[Image]]) -> List[str]: - if not isinstance(batch, tuple) and batch.shape[0] == 0: + # TODO: necessary? + if not batch: return [] - return list(self.__pipe(batch)) + if isinstance(batch, np.ndarray) and batch.shape[0] == 0: + return [] + + return list(self.__pipe(batch)) # TODO: list? def __call__(self, batch: np.array) -> List[str]: logger.debug("Classifier.predict") diff --git a/image_prediction/image_extractor/extractors/parsable.py b/image_prediction/image_extractor/extractors/parsable.py index 3448f6f..f98348d 100644 --- a/image_prediction/image_extractor/extractors/parsable.py +++ b/image_prediction/image_extractor/extractors/parsable.py @@ -35,17 +35,17 @@ class ParsablePDFImageExtractor(ImageExtractor): pages = extract_pages(self.doc, page_range) if page_range else self.doc - pages = self.__maybe_show_progress(pages, page_range) + # pages = self.__maybe_show_progress(pages, page_range) image_metadata_pairs = chain.from_iterable(map(self.__process_images_on_page, pages)) yield from image_metadata_pairs - def __maybe_show_progress(self, iterable, page_range): - return self.__progressbar(page_range)(iterable) if self.verbose else iterable - - def __progressbar(self, page_range): - return rpartial(tqdm, desc=self.progress_message, total=len(page_range) if page_range else None) + # def __maybe_show_progress(self, iterable, page_range): + # return self.__progressbar(page_range)(iterable) if self.verbose else iterable + # + # def __progressbar(self, page_range): + # return rpartial(tqdm, desc=self.progress_message, total=len(page_range) if page_range else None) def __process_images_on_page(self, page: fitz.fitz.Page): images = get_images_on_page(self.doc, page) diff --git a/image_prediction/pipeline.py b/image_prediction/pipeline.py index fad4145..7392d72 100644 --- a/image_prediction/pipeline.py +++ b/image_prediction/pipeline.py @@ -1,10 +1,12 @@ import os +from itertools import chain -from funcy import rcompose +from funcy import rcompose, juxt, first, compose, second, chunks, curry from image_prediction.config import CONFIG -from image_prediction.default_objects import get_extractor_classifier, get_formatter, get_mlflow_model_loader +from image_prediction.default_objects import get_formatter, get_mlflow_model_loader, get_image_classifier, get_extractor from image_prediction.locations import MLRUNS_DIR +from image_prediction.utils.generic import lift, starlift os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3" @@ -19,8 +21,37 @@ def load_pipeline(**kwargs): class Pipeline: - def __init__(self, model_loader, model_identifier, **kwargs): - self.pipe = rcompose(get_extractor_classifier(model_loader, model_identifier, **kwargs), get_formatter()) + def __init__(self, model_loader, model_identifier, batch_size=16, **kwargs): + extractor = get_extractor(**kwargs) + batcher = lambda x: chunks(batch_size, x) + classifier = get_image_classifier(model_loader, model_identifier) + merger = lambda predictions, metadata: ({"classification": prd, **mdt} for prd, mdt in zip(predictions, metadata)) + formatter = get_formatter() + + left = compose(classifier, lift(first)) + right = lift(second) + + # -------- + # -- -- -- -- + # == == == == + # -- -- -- -- + # -------- + # -------- + + def inspect(x): + import IPython + IPython.embed() + return x + + self.pipe = rcompose( + extractor, + batcher, + lift(list), + lift(juxt(left, right)), + starlift(merger), + chain.from_iterable, + formatter, + ) def __call__(self, pdf: bytes, page_range: range = None): yield from self.pipe(pdf, page_range=page_range) diff --git a/image_prediction/utils/generic.py b/image_prediction/utils/generic.py index 9b25640..de71a5c 100644 --- a/image_prediction/utils/generic.py +++ b/image_prediction/utils/generic.py @@ -1,3 +1,5 @@ +from itertools import starmap + from funcy import iterate, first, curry, map @@ -7,3 +9,7 @@ def until(cond, func, *args, **kwargs): def lift(fn): return curry(map)(fn) + + +def starlift(fn): + return curry(starmap)(fn)