Julius Unverfehrt 265c61df1a RED-5107 add handling for 'broken' images: broken image parts are
replaced by blank images in the stitching process and completely broken
images are also replaced by blank images which are passed through and
are classified as 'other' with all_pased == False. This should be
changed in the future by introducing a new key to the response,
indicating that the image is not valid.
2022-08-30 12:30:23 +02:00

65 lines
1.9 KiB
Python

import abc
import numpy as np
import tensorflow as tf
from PIL import Image
from image_prediction.utils import get_logger
# Module-level logger obtained from the project's get_logger() helper; the
# ModelWrapper class below uses it to report images that had to be replaced.
logger = get_logger()
class ModelWrapper(abc.ABC):
    """Abstract base for Keras image classifiers.

    Subclasses supply the network (``__build``), the expected input shape
    (``input_shape``) and the model-specific tensor preprocessing
    (``__preprocess_tensor``). This base class loads weights and turns a batch
    of PIL images into a model-ready tensor. An image that cannot be resized or
    converted is replaced by a blank placeholder and flagged as invalid, so the
    rest of the batch is not aborted.

    NOTE(review): ``__build`` and ``__preprocess_tensor`` are name-mangled to
    ``_ModelWrapper__build`` and ``_ModelWrapper__preprocess_tensor``. A
    subclass that defines a plain ``__build`` actually defines
    ``_<Subclass>__build``, which does NOT override the abstract method, so the
    subclass stays abstract. Subclasses must define the ``_ModelWrapper__...``
    names explicitly. The names are kept as-is for compatibility with existing
    subclasses.
    """

    def __init__(self, classes, base_weights_path=None, weights_path=None):
        """Build the model and optionally load trained weights into it.

        Args:
            classes: Class labels, exposed unchanged via :attr:`classes`.
            base_weights_path: Forwarded to the subclass ``__build`` hook.
            weights_path: File passed to ``model.load_weights``. When ``None``
                (the default), loading is skipped and the model keeps whatever
                weights ``__build`` produced.
        """
        self.__classes = classes
        self.model = self.__build(base_weights_path)
        # load_weights() expects a file path, so the previous unconditional
        # call made the None default unusable. Treat None as "nothing extra to
        # load", but say so, since predictions may then come from untrained
        # layers.
        if weights_path is None:
            logger.warning("No weights_path given; skipping model.load_weights.")
        else:
            self.model.load_weights(weights_path)

    @property
    @abc.abstractmethod
    def input_shape(self):
        """Model input shape; everything except the last entry is the resize target.

        NOTE(review): ``input_shape[:-1]`` is given to PIL, which interprets it
        as (width, height). If this is (height, width, channels), as Keras
        usually reports it, the two only agree for square inputs. Confirm.
        """
        raise NotImplementedError

    @property
    def classes(self):
        """Class labels passed to the constructor."""
        return self.__classes

    @abc.abstractmethod
    def __preprocess_tensor(self, tensor):
        """Apply model-specific preprocessing to the batch array (subclass hook)."""
        raise NotImplementedError

    @staticmethod
    def __images_to_tensor(images):
        """Convert each image with Keras' ``img_to_array`` and stack into one numpy array."""
        return np.array(
            [tf.keras.preprocessing.image.img_to_array(image) for image in images]
        )

    def __resize_and_convert(self, image):
        """Resize ``image`` to the model input size and convert it to RGB."""
        return image.resize(self.input_shape[:-1]).convert("RGB")

    def __monitored_resize_and_convert(self, image):
        """Resize/convert ``image``, substituting a blank placeholder on failure.

        Returns:
            tuple: ``(image, valid)``. ``valid`` is False when the original
            image could not be processed and a placeholder was used instead.
        """
        # RED-5170: resizing fails if the image is 'broken'. The catch is
        # deliberately broad so a single bad image cannot abort the batch.
        # OSError is a subclass of Exception, so one handler covers both of
        # the cases the code previously handled separately.
        try:
            return self.__resize_and_convert(image), True
        except Exception as err:
            return self.__handle_resize_exception(err)

    def __handle_resize_exception(self, err):
        """Log ``err`` and return a blank (black) RGB placeholder marked invalid."""
        logger.warning(f"{err}: couldn't resize image, replace and passthrough.")
        image = Image.new("RGB", self.input_shape[:-1])
        valid = False
        return image, valid

    def prep_images(self, images):
        """Turn an iterable of PIL images into a preprocessed model input batch.

        Args:
            images: Iterable of PIL images.

        Returns:
            tuple: ``(tensor, valid_mask)``. ``tensor`` is the preprocessed
            batch. ``valid_mask`` is a tuple of booleans, one per input image,
            that is False where a blank placeholder was substituted.

        Raises:
            ValueError: If ``images`` is empty. ``zip(*...)`` then yields
                nothing to unpack into the two result names.
        """
        images, valid_mask = zip(*map(self.__monitored_resize_and_convert, images))
        tensor = self.__images_to_tensor(images)
        tensor = self.__preprocess_tensor(tensor)
        return tensor, valid_mask

    @abc.abstractmethod
    def __build(self, base_weights=None) -> tf.keras.models.Model:
        """Construct and return the Keras model (subclass hook).

        ``base_weights`` receives the constructor's ``base_weights_path``.
        """
        raise NotImplementedError