From 265c61df1aa1436bac946bda28ed33fd884df8c6 Mon Sep 17 00:00:00 2001 From: Julius Unverfehrt Date: Tue, 30 Aug 2022 10:54:26 +0200 Subject: [PATCH] RED-5107 add handling for 'broken' images: broken image parts are replaced by blank images in the stitching process and completely broken images are also replaced by blank images which are passed through and are classified as 'other' with all_pased == False. This should be changed in the future by introducing a new key to the response, indicating that the image is not valid. --- .../image_extractor/extractors/parsable.py | 14 ---------- .../label_mapper/mappers/probability.py | 7 +++-- image_prediction/pipeline.py | 1 - image_prediction/redai_adapter/model.py | 10 +++---- .../redai_adapter/model_wrapper.py | 27 ++++++++++++++++--- image_prediction/stitching/merging.py | 2 +- .../transformer/transformers/response.py | 9 +++++-- test/fixtures/model.py | 2 +- 8 files changed, 43 insertions(+), 29 deletions(-) diff --git a/image_prediction/image_extractor/extractors/parsable.py b/image_prediction/image_extractor/extractors/parsable.py index 8914b01..9fe5b46 100644 --- a/image_prediction/image_extractor/extractors/parsable.py +++ b/image_prediction/image_extractor/extractors/parsable.py @@ -50,20 +50,6 @@ class ParsablePDFImageExtractor(ImageExtractor): yield from image_metadata_pairs - # def __preprocess(self, image_metadata_pair): - # image, metadata = image_metadata_pair - # - # try: - # image = self.__resize_and_convert(image) - # image_metadata_pair = ImageMetadataPair(image, metadata) - # except Exception as err: - # logger.warn( - # f"{err}: couldn't preprocess image [ page_idx: {metadata[Info.PAGE_IDX]}, x1: {metadata[Info.X1]}, y1: {metadata[Info.Y1]}, width: {metadata[Info.WIDTH]}, height: {metadata[Info.HEIGHT]} ]" - # ) - # image_metadata_pair = None - # - # return image_metadata_pair - def extract_pages(doc, page_range): page_range = range(page_range.start + 1, page_range.stop + 1) diff --git 
a/image_prediction/label_mapper/mappers/probability.py b/image_prediction/label_mapper/mappers/probability.py index b2a0e63..e84b4bf 100644 --- a/image_prediction/label_mapper/mappers/probability.py +++ b/image_prediction/label_mapper/mappers/probability.py @@ -1,6 +1,6 @@ from enum import Enum from operator import itemgetter -from typing import Mapping, Iterable +from typing import Mapping, Iterable, Union import numpy as np from funcy import rcompose, rpartial @@ -27,7 +27,10 @@ class ProbabilityMapper(LabelMapper): f"Received fewer probabilities ({len(probabilities)}) than labels were passed ({len(self.__labels)})." ) - def __map_array(self, probabilities: np.ndarray) -> dict: + def __map_array(self, probabilities: Union[np.ndarray, None]) -> Union[dict, None]: + if not isinstance(probabilities, np.ndarray) and not probabilities: + return None + self.__validate_array_label_format(probabilities) cls2prob = dict( sorted(zip(self.__labels, list(map(self.__rounder, probabilities))), key=itemgetter(1), reverse=True) diff --git a/image_prediction/pipeline.py b/image_prediction/pipeline.py index e8114b8..6d29ac7 100644 --- a/image_prediction/pipeline.py +++ b/image_prediction/pipeline.py @@ -48,7 +48,6 @@ class Pipeline: split = compose(star(parallel(*map(lift, (first, first, second)))), rpartial(tee, 3)) classify = compose(chain.from_iterable, lift(classifier), partial(chunks, batch_size)) pairwise_apply = compose(star, parallel) - # TODO: use signal compress join = compose(starlift(lambda prd, rpr, mdt: {"classification": prd, **mdt, "representation": rpr}), star(zip)) # />--classify--\ diff --git a/image_prediction/redai_adapter/model.py b/image_prediction/redai_adapter/model.py index fabcb78..a0ac3a8 100644 --- a/image_prediction/redai_adapter/model.py +++ b/image_prediction/redai_adapter/model.py @@ -1,5 +1,3 @@ -from funcy import rcompose - from image_prediction.utils import get_logger logger = get_logger() @@ -9,11 +7,13 @@ class PredictionModelHandle: 
"""Simplifies usage of ModelHandle instances for prediction purposes.""" def __init__(self, model_handle): - # TODO: extract signal - self.__predict = rcompose(model_handle.prep_images, model_handle.model.predict) + self.__prep_images = model_handle.prep_images + self.__predict = model_handle.model.predict def predict(self, *args, **kwargs): - return self.__predict(*args, **kwargs) + tensor, valid_mask = self.__prep_images(*args, **kwargs) + predictions = self.__predict(tensor, **kwargs) + return [p if v else None for p, v in zip(predictions, valid_mask)] def __call__(self, *args, **kwargs): logger.debug("PredictionModelHandle.predict") diff --git a/image_prediction/redai_adapter/model_wrapper.py b/image_prediction/redai_adapter/model_wrapper.py index 89aa49b..de9d543 100644 --- a/image_prediction/redai_adapter/model_wrapper.py +++ b/image_prediction/redai_adapter/model_wrapper.py @@ -2,6 +2,11 @@ import abc import numpy as np import tensorflow as tf +from PIL import Image + +from image_prediction.utils import get_logger + +logger = get_logger() class ModelWrapper(abc.ABC): @@ -30,13 +35,29 @@ class ModelWrapper(abc.ABC): def __resize_and_convert(self, image): return image.resize(self.input_shape[:-1]).convert("RGB") + def __monitored_resize_and_convert(self, image): + # RED-5170: fails if image is 'broken' + try: + image, valid = self.__resize_and_convert(image), True + except OSError as err: + image, valid = self.__handle_resize_exception(err) + except Exception as err: + image, valid = self.__handle_resize_exception(err) + + return image, valid + + def __handle_resize_exception(self, err): + logger.warn(f"{err}: couldn't resize image, replace and passthrough.") + image = Image.new("RGB", self.input_shape[:-1]) + valid = False + return image, valid + def prep_images(self, images): - images = map(self.__resize_and_convert, images) - # TODO: signal, images = ... 
+ images, valid_mask = zip(*map(self.__monitored_resize_and_convert, images)) tensor = self.__images_to_tensor(images) tensor = self.__preprocess_tensor(tensor) - return tensor + return tensor, valid_mask @abc.abstractmethod def __build(self, base_weights=None) -> tf.keras.models.Model: diff --git a/image_prediction/stitching/merging.py b/image_prediction/stitching/merging.py index 96da32b..fb96665 100644 --- a/image_prediction/stitching/merging.py +++ b/image_prediction/stitching/merging.py @@ -191,6 +191,6 @@ def concat_images(im1: Image, im2: Image, metadata: dict, axis): try: # RED-5170: fails if image is 'broken' im_aggr.paste(im, box=box) except Exception as err: - logger.warn(f"{err}: couldn't merge images, replace and passthrough. (page: {metadata[Info.PAGE_IDX]})") + logger.warn(f"{err}: couldn't merge image, replace and passthrough. (page: {metadata[Info.PAGE_IDX]})") return im_aggr diff --git a/image_prediction/transformer/transformers/response.py b/image_prediction/transformer/transformers/response.py index 3e35104..a3dc601 100644 --- a/image_prediction/transformer/transformers/response.py +++ b/image_prediction/transformer/transformers/response.py @@ -35,10 +35,15 @@ def build_image_info(data: dict) -> dict: width / height > CONFIG.filters.image_width_to_height_quotient.max ) - classification = data["classification"] + # FIXME: pass in fallback value for classification and introduce new key for image validness + classification = data["classification"] or "other" representation = data["representation"] - min_confidence_breached = bool(max(classification["probabilities"].values()) < CONFIG.filters.min_confidence) + min_confidence_breached = ( + bool(max(classification["probabilities"].values()) < CONFIG.filters.min_confidence) + if data["classification"] + else True + ) image_info = { "classification": classification, diff --git a/test/fixtures/model.py b/test/fixtures/model.py index 729d234..812da4b 100644 --- a/test/fixtures/model.py +++ 
b/test/fixtures/model.py @@ -102,7 +102,7 @@ def model_handle_mock(estimator_mock): self.model = estimator_mock def prep_images(self, batch): - return [None for _ in batch] + return [True for _ in batch], [None for _ in batch] def predict(self, batch): return [None for _ in batch]