RED-5107 add handling for 'broken' images: broken image parts are
replaced by blank images in the stitching process and completly broken images are also replaced by blank images which are passed through and are classified as 'other' with all_pased == False. This should be changed in the future by introducing a new key to the response, indicating that the image is not valid.
This commit is contained in:
parent
7c6f9809bc
commit
265c61df1a
@ -50,20 +50,6 @@ class ParsablePDFImageExtractor(ImageExtractor):
|
||||
|
||||
yield from image_metadata_pairs
|
||||
|
||||
# def __preprocess(self, image_metadata_pair):
|
||||
# image, metadata = image_metadata_pair
|
||||
#
|
||||
# try:
|
||||
# image = self.__resize_and_convert(image)
|
||||
# image_metadata_pair = ImageMetadataPair(image, metadata)
|
||||
# except Exception as err:
|
||||
# logger.warn(
|
||||
# f"{err}: couldn't preprocess image [ page_idx: {metadata[Info.PAGE_IDX]}, x1: {metadata[Info.X1]}, y1: {metadata[Info.Y1]}, width: {metadata[Info.WIDTH]}, height: {metadata[Info.HEIGHT]} ]"
|
||||
# )
|
||||
# image_metadata_pair = None
|
||||
#
|
||||
# return image_metadata_pair
|
||||
|
||||
|
||||
def extract_pages(doc, page_range):
|
||||
page_range = range(page_range.start + 1, page_range.stop + 1)
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
from enum import Enum
|
||||
from operator import itemgetter
|
||||
from typing import Mapping, Iterable
|
||||
from typing import Mapping, Iterable, Union
|
||||
|
||||
import numpy as np
|
||||
from funcy import rcompose, rpartial
|
||||
@ -27,7 +27,10 @@ class ProbabilityMapper(LabelMapper):
|
||||
f"Received fewer probabilities ({len(probabilities)}) than labels were passed ({len(self.__labels)})."
|
||||
)
|
||||
|
||||
def __map_array(self, probabilities: np.ndarray) -> dict:
|
||||
def __map_array(self, probabilities: Union[np.ndarray, None]) -> Union[dict, None]:
|
||||
if not isinstance(probabilities, np.ndarray) and not probabilities:
|
||||
return None
|
||||
|
||||
self.__validate_array_label_format(probabilities)
|
||||
cls2prob = dict(
|
||||
sorted(zip(self.__labels, list(map(self.__rounder, probabilities))), key=itemgetter(1), reverse=True)
|
||||
|
||||
@ -48,7 +48,6 @@ class Pipeline:
|
||||
split = compose(star(parallel(*map(lift, (first, first, second)))), rpartial(tee, 3))
|
||||
classify = compose(chain.from_iterable, lift(classifier), partial(chunks, batch_size))
|
||||
pairwise_apply = compose(star, parallel)
|
||||
# TODO: use signal compress
|
||||
join = compose(starlift(lambda prd, rpr, mdt: {"classification": prd, **mdt, "representation": rpr}), star(zip))
|
||||
|
||||
# />--classify--\
|
||||
|
||||
@ -1,5 +1,3 @@
|
||||
from funcy import rcompose
|
||||
|
||||
from image_prediction.utils import get_logger
|
||||
|
||||
logger = get_logger()
|
||||
@ -9,11 +7,13 @@ class PredictionModelHandle:
|
||||
"""Simplifies usage of ModelHandle instances for prediction purposes."""
|
||||
|
||||
def __init__(self, model_handle):
|
||||
# TODO: extract signal
|
||||
self.__predict = rcompose(model_handle.prep_images, model_handle.model.predict)
|
||||
self.__prep_images = model_handle.prep_images
|
||||
self.__predict = model_handle.model.predict
|
||||
|
||||
def predict(self, *args, **kwargs):
|
||||
return self.__predict(*args, **kwargs)
|
||||
tensor, valid_mask = self.__prep_images(*args, **kwargs)
|
||||
predictions = self.__predict(tensor, **kwargs)
|
||||
return [p if v else None for p, v in zip(predictions, valid_mask)]
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
logger.debug("PredictionModelHandle.predict")
|
||||
|
||||
@ -2,6 +2,11 @@ import abc
|
||||
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
from PIL import Image
|
||||
|
||||
from image_prediction.utils import get_logger
|
||||
|
||||
logger = get_logger()
|
||||
|
||||
|
||||
class ModelWrapper(abc.ABC):
|
||||
@ -30,13 +35,29 @@ class ModelWrapper(abc.ABC):
|
||||
def __resize_and_convert(self, image):
|
||||
return image.resize(self.input_shape[:-1]).convert("RGB")
|
||||
|
||||
def __monitored_resize_and_convert(self, image):
|
||||
# RED-5170: fails if image is 'broken'
|
||||
try:
|
||||
image, valid = self.__resize_and_convert(image), True
|
||||
except OSError as err:
|
||||
image, valid = self.__handle_resize_exception(err)
|
||||
except Exception as err:
|
||||
image, valid = self.__handle_resize_exception(err)
|
||||
|
||||
return image, valid
|
||||
|
||||
def __handle_resize_exception(self, err):
|
||||
logger.warn(f"{err}: couldn't resize image, replace and passthrough.")
|
||||
image = Image.new("RGB", self.input_shape[:-1])
|
||||
valid = False
|
||||
return image, valid
|
||||
|
||||
def prep_images(self, images):
|
||||
images = map(self.__resize_and_convert, images)
|
||||
# TODO: signal, images = ...
|
||||
images, valid_mask = zip(*map(self.__monitored_resize_and_convert, images))
|
||||
tensor = self.__images_to_tensor(images)
|
||||
tensor = self.__preprocess_tensor(tensor)
|
||||
|
||||
return tensor
|
||||
return tensor, valid_mask
|
||||
|
||||
@abc.abstractmethod
|
||||
def __build(self, base_weights=None) -> tf.keras.models.Model:
|
||||
|
||||
@ -191,6 +191,6 @@ def concat_images(im1: Image, im2: Image, metadata: dict, axis):
|
||||
try: # RED-5170: fails if image is 'broken'
|
||||
im_aggr.paste(im, box=box)
|
||||
except Exception as err:
|
||||
logger.warn(f"{err}: couldn't merge images, replace and passthrough. (page: {metadata[Info.PAGE_IDX]})")
|
||||
logger.warn(f"{err}: couldn't merge image, replace and passthrough. (page: {metadata[Info.PAGE_IDX]})")
|
||||
|
||||
return im_aggr
|
||||
|
||||
@ -35,10 +35,15 @@ def build_image_info(data: dict) -> dict:
|
||||
width / height > CONFIG.filters.image_width_to_height_quotient.max
|
||||
)
|
||||
|
||||
classification = data["classification"]
|
||||
# FIXME: pass in fallback value for classification and introduce new key for image validness
|
||||
classification = data["classification"] or "other"
|
||||
representation = data["representation"]
|
||||
|
||||
min_confidence_breached = bool(max(classification["probabilities"].values()) < CONFIG.filters.min_confidence)
|
||||
min_confidence_breached = (
|
||||
bool(max(classification["probabilities"].values()) < CONFIG.filters.min_confidence)
|
||||
if data["classification"]
|
||||
else True
|
||||
)
|
||||
|
||||
image_info = {
|
||||
"classification": classification,
|
||||
|
||||
2
test/fixtures/model.py
vendored
2
test/fixtures/model.py
vendored
@ -102,7 +102,7 @@ def model_handle_mock(estimator_mock):
|
||||
self.model = estimator_mock
|
||||
|
||||
def prep_images(self, batch):
|
||||
return [None for _ in batch]
|
||||
return [True for _ in batch], [None for _ in batch]
|
||||
|
||||
def predict(self, batch):
|
||||
return [None for _ in batch]
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user