RED-5107 add handling for 'broken' images: broken image parts are

replaced by blank images in the stitching process and completely broken
images are also replaced by blank images which are passed through and
are classified as 'other' with all_passed == False. This should be
changed in the future by introducing a new key to the response,
indicating that the image is not valid.
This commit is contained in:
Julius Unverfehrt 2022-08-30 10:54:26 +02:00
parent 7c6f9809bc
commit 265c61df1a
8 changed files with 43 additions and 29 deletions

View File

@ -50,20 +50,6 @@ class ParsablePDFImageExtractor(ImageExtractor):
yield from image_metadata_pairs
# def __preprocess(self, image_metadata_pair):
# image, metadata = image_metadata_pair
#
# try:
# image = self.__resize_and_convert(image)
# image_metadata_pair = ImageMetadataPair(image, metadata)
# except Exception as err:
# logger.warn(
# f"{err}: couldn't preprocess image [ page_idx: {metadata[Info.PAGE_IDX]}, x1: {metadata[Info.X1]}, y1: {metadata[Info.Y1]}, width: {metadata[Info.WIDTH]}, height: {metadata[Info.HEIGHT]} ]"
# )
# image_metadata_pair = None
#
# return image_metadata_pair
def extract_pages(doc, page_range):
page_range = range(page_range.start + 1, page_range.stop + 1)

View File

@ -1,6 +1,6 @@
from enum import Enum
from operator import itemgetter
from typing import Mapping, Iterable
from typing import Mapping, Iterable, Union
import numpy as np
from funcy import rcompose, rpartial
@ -27,7 +27,10 @@ class ProbabilityMapper(LabelMapper):
f"Received fewer probabilities ({len(probabilities)}) than labels were passed ({len(self.__labels)})."
)
def __map_array(self, probabilities: np.ndarray) -> dict:
def __map_array(self, probabilities: Union[np.ndarray, None]) -> Union[dict, None]:
if not isinstance(probabilities, np.ndarray) and not probabilities:
return None
self.__validate_array_label_format(probabilities)
cls2prob = dict(
sorted(zip(self.__labels, list(map(self.__rounder, probabilities))), key=itemgetter(1), reverse=True)

View File

@ -48,7 +48,6 @@ class Pipeline:
split = compose(star(parallel(*map(lift, (first, first, second)))), rpartial(tee, 3))
classify = compose(chain.from_iterable, lift(classifier), partial(chunks, batch_size))
pairwise_apply = compose(star, parallel)
# TODO: use signal compress
join = compose(starlift(lambda prd, rpr, mdt: {"classification": prd, **mdt, "representation": rpr}), star(zip))
# />--classify--\

View File

@ -1,5 +1,3 @@
from funcy import rcompose
from image_prediction.utils import get_logger
logger = get_logger()
@ -9,11 +7,13 @@ class PredictionModelHandle:
"""Simplifies usage of ModelHandle instances for prediction purposes."""
def __init__(self, model_handle):
# TODO: extract signal
self.__predict = rcompose(model_handle.prep_images, model_handle.model.predict)
self.__prep_images = model_handle.prep_images
self.__predict = model_handle.model.predict
def predict(self, *args, **kwargs):
return self.__predict(*args, **kwargs)
tensor, valid_mask = self.__prep_images(*args, **kwargs)
predictions = self.__predict(tensor, **kwargs)
return [p if v else None for p, v in zip(predictions, valid_mask)]
def __call__(self, *args, **kwargs):
logger.debug("PredictionModelHandle.predict")

View File

@ -2,6 +2,11 @@ import abc
import numpy as np
import tensorflow as tf
from PIL import Image
from image_prediction.utils import get_logger
logger = get_logger()
class ModelWrapper(abc.ABC):
@ -30,13 +35,29 @@ class ModelWrapper(abc.ABC):
def __resize_and_convert(self, image):
    """Resize ``image`` to ``self.input_shape`` minus its last entry and convert it to RGB."""
    # All but the last entry of input_shape is handed to resize() as the target size.
    target_size = self.input_shape[:-1]
    resized = image.resize(target_size)
    return resized.convert("RGB")
def __monitored_resize_and_convert(self, image):
    """Resize and convert ``image``, falling back to a replacement if that fails.

    Returns:
        A tuple ``(image, valid)``. On success this is the result of
        ``__resize_and_convert`` together with ``True``; on any exception it is
        the pair produced by ``__handle_resize_exception`` for the caught error.
    """
    # RED-5170: fails if image is 'broken'
    try:
        image, valid = self.__resize_and_convert(image), True
    except Exception as err:
        # OSError is a subclass of Exception and was handled identically,
        # so one handler covers both former clauses.
        image, valid = self.__handle_resize_exception(err)
    return image, valid
def __handle_resize_exception(self, err):
    """Log a failed resize and build a replacement image.

    Args:
        err: The exception raised while resizing; it is interpolated into the
            log message.

    Returns:
        A tuple ``(image, valid)`` where ``image`` is a newly created RGB image
        sized ``self.input_shape[:-1]`` and ``valid`` is always ``False``.
    """
    # Logger.warn is a deprecated alias of Logger.warning (removed in Python 3.13).
    # NOTE(review): assumes get_logger() returns a standard logging.Logger - confirm.
    logger.warning(f"{err}: couldn't resize image, replace and passthrough.")
    image = Image.new("RGB", self.input_shape[:-1])
    valid = False
    return image, valid
def prep_images(self, images):
images = map(self.__resize_and_convert, images)
# TODO: signal, images = ...
images, valid_mask = zip(*map(self.__monitored_resize_and_convert, images))
tensor = self.__images_to_tensor(images)
tensor = self.__preprocess_tensor(tensor)
return tensor
return tensor, valid_mask
@abc.abstractmethod
def __build(self, base_weights=None) -> tf.keras.models.Model:

View File

@ -191,6 +191,6 @@ def concat_images(im1: Image, im2: Image, metadata: dict, axis):
try: # RED-5170: fails if image is 'broken'
im_aggr.paste(im, box=box)
except Exception as err:
logger.warn(f"{err}: couldn't merge images, replace and passthrough. (page: {metadata[Info.PAGE_IDX]})")
logger.warn(f"{err}: couldn't merge image, replace and passthrough. (page: {metadata[Info.PAGE_IDX]})")
return im_aggr

View File

@ -35,10 +35,15 @@ def build_image_info(data: dict) -> dict:
width / height > CONFIG.filters.image_width_to_height_quotient.max
)
classification = data["classification"]
# FIXME: pass in fallback value for classification and introduce new key for image validity
classification = data["classification"] or "other"
representation = data["representation"]
min_confidence_breached = bool(max(classification["probabilities"].values()) < CONFIG.filters.min_confidence)
min_confidence_breached = (
bool(max(classification["probabilities"].values()) < CONFIG.filters.min_confidence)
if data["classification"]
else True
)
image_info = {
"classification": classification,

View File

@ -102,7 +102,7 @@ def model_handle_mock(estimator_mock):
self.model = estimator_mock
def prep_images(self, batch):
return [None for _ in batch]
return [True for _ in batch], [None for _ in batch]
def predict(self, batch):
return [None for _ in batch]