diff --git a/image_prediction/image_extractor/extractors/parsable.py b/image_prediction/image_extractor/extractors/parsable.py index 8914b01..9fe5b46 100644 --- a/image_prediction/image_extractor/extractors/parsable.py +++ b/image_prediction/image_extractor/extractors/parsable.py @@ -50,20 +50,6 @@ class ParsablePDFImageExtractor(ImageExtractor): yield from image_metadata_pairs - # def __preprocess(self, image_metadata_pair): - # image, metadata = image_metadata_pair - # - # try: - # image = self.__resize_and_convert(image) - # image_metadata_pair = ImageMetadataPair(image, metadata) - # except Exception as err: - # logger.warn( - # f"{err}: couldn't preprocess image [ page_idx: {metadata[Info.PAGE_IDX]}, x1: {metadata[Info.X1]}, y1: {metadata[Info.Y1]}, width: {metadata[Info.WIDTH]}, height: {metadata[Info.HEIGHT]} ]" - # ) - # image_metadata_pair = None - # - # return image_metadata_pair - def extract_pages(doc, page_range): page_range = range(page_range.start + 1, page_range.stop + 1) diff --git a/image_prediction/label_mapper/mappers/probability.py b/image_prediction/label_mapper/mappers/probability.py index b2a0e63..e84b4bf 100644 --- a/image_prediction/label_mapper/mappers/probability.py +++ b/image_prediction/label_mapper/mappers/probability.py @@ -1,6 +1,6 @@ from enum import Enum from operator import itemgetter -from typing import Mapping, Iterable +from typing import Mapping, Iterable, Union import numpy as np from funcy import rcompose, rpartial @@ -27,7 +27,10 @@ class ProbabilityMapper(LabelMapper): f"Received fewer probabilities ({len(probabilities)}) than labels were passed ({len(self.__labels)})." ) - def __map_array(self, probabilities: np.ndarray) -> dict: + def __map_array(self, probabilities: Union[np.ndarray, None]) -> Union[dict, None]: + if not isinstance(probabilities, np.ndarray) and not probabilities: + return None + self.__validate_array_label_format(probabilities) cls2prob = dict( sorted(zip(self.__labels, list(map(self.__rounder, probabilities))), key=itemgetter(1), reverse=True) diff --git a/image_prediction/pipeline.py b/image_prediction/pipeline.py index e8114b8..6d29ac7 100644 --- a/image_prediction/pipeline.py +++ b/image_prediction/pipeline.py @@ -48,7 +48,6 @@ class Pipeline: split = compose(star(parallel(*map(lift, (first, first, second)))), rpartial(tee, 3)) classify = compose(chain.from_iterable, lift(classifier), partial(chunks, batch_size)) pairwise_apply = compose(star, parallel) - # TODO: use signal compress join = compose(starlift(lambda prd, rpr, mdt: {"classification": prd, **mdt, "representation": rpr}), star(zip)) # />--classify--\ diff --git a/image_prediction/redai_adapter/model.py b/image_prediction/redai_adapter/model.py index fabcb78..a0ac3a8 100644 --- a/image_prediction/redai_adapter/model.py +++ b/image_prediction/redai_adapter/model.py @@ -1,5 +1,3 @@ -from funcy import rcompose - from image_prediction.utils import get_logger logger = get_logger() @@ -9,11 +7,13 @@ class PredictionModelHandle: """Simplifies usage of ModelHandle instances for prediction purposes.""" def __init__(self, model_handle): - # TODO: extract signal - self.__predict = rcompose(model_handle.prep_images, model_handle.model.predict) + self.__prep_images = model_handle.prep_images + self.__predict = model_handle.model.predict def predict(self, *args, **kwargs): - return self.__predict(*args, **kwargs) + tensor, valid_mask = self.__prep_images(*args, **kwargs) + predictions = self.__predict(tensor, **kwargs) + return [p if v else None for p, v in zip(predictions, valid_mask)] def __call__(self, *args, **kwargs): logger.debug("PredictionModelHandle.predict") diff --git a/image_prediction/redai_adapter/model_wrapper.py b/image_prediction/redai_adapter/model_wrapper.py index 89aa49b..de9d543 100644 --- a/image_prediction/redai_adapter/model_wrapper.py +++ b/image_prediction/redai_adapter/model_wrapper.py @@ -2,6 +2,11 @@ import abc import numpy as np import tensorflow as tf +from PIL import Image + +from image_prediction.utils import get_logger + +logger = get_logger() class ModelWrapper(abc.ABC): @@ -30,13 +35,29 @@ class ModelWrapper(abc.ABC): def __resize_and_convert(self, image): return image.resize(self.input_shape[:-1]).convert("RGB") + def __monitored_resize_and_convert(self, image): + # RED-5170: fails if image is 'broken' + try: + image, valid = self.__resize_and_convert(image), True + except OSError as err: + image, valid = self.__handle_resize_exception(err) + except Exception as err: + image, valid = self.__handle_resize_exception(err) + + return image, valid + + def __handle_resize_exception(self, err): + logger.warn(f"{err}: couldn't resize image, replace and passthrough.") + image = Image.new("RGB", self.input_shape[:-1]) + valid = False + return image, valid + def prep_images(self, images): - images = map(self.__resize_and_convert, images) - # TODO: signal, images = ... + images, valid_mask = zip(*map(self.__monitored_resize_and_convert, images)) tensor = self.__images_to_tensor(images) tensor = self.__preprocess_tensor(tensor) - return tensor + return tensor, valid_mask @abc.abstractmethod def __build(self, base_weights=None) -> tf.keras.models.Model: diff --git a/image_prediction/stitching/merging.py b/image_prediction/stitching/merging.py index 96da32b..fb96665 100644 --- a/image_prediction/stitching/merging.py +++ b/image_prediction/stitching/merging.py @@ -191,6 +191,6 @@ def concat_images(im1: Image, im2: Image, metadata: dict, axis): try: # RED-5170: fails if image is 'broken' im_aggr.paste(im, box=box) except Exception as err: - logger.warn(f"{err}: couldn't merge images, replace and passthrough. (page: {metadata[Info.PAGE_IDX]})") + logger.warn(f"{err}: couldn't merge image, replace and passthrough. (page: {metadata[Info.PAGE_IDX]})") return im_aggr diff --git a/image_prediction/transformer/transformers/response.py b/image_prediction/transformer/transformers/response.py index 3e35104..a3dc601 100644 --- a/image_prediction/transformer/transformers/response.py +++ b/image_prediction/transformer/transformers/response.py @@ -35,10 +35,15 @@ def build_image_info(data: dict) -> dict: width / height > CONFIG.filters.image_width_to_height_quotient.max ) - classification = data["classification"] + # FIXME: pass in fallback value for classification and introduce new key for image validness + classification = data["classification"] or "other" representation = data["representation"] - min_confidence_breached = bool(max(classification["probabilities"].values()) < CONFIG.filters.min_confidence) + min_confidence_breached = ( + bool(max(classification["probabilities"].values()) < CONFIG.filters.min_confidence) + if data["classification"] + else True + ) image_info = { "classification": classification, diff --git a/test/fixtures/model.py b/test/fixtures/model.py index 729d234..812da4b 100644 --- a/test/fixtures/model.py +++ b/test/fixtures/model.py @@ -102,7 +102,7 @@ def model_handle_mock(estimator_mock): self.model = estimator_mock def prep_images(self, batch): - return [None for _ in batch] + return [True for _ in batch], [None for _ in batch] def predict(self, batch): return [None for _ in batch]