import abc

import numpy as np
import tensorflow as tf
from PIL import Image

from image_prediction.utils import get_logger

logger = get_logger()


class ModelWrapper(abc.ABC):
    """Base class bundling a Keras model with the image preprocessing it needs.

    Subclasses supply the model input shape, the model construction hook and
    the model-specific tensor preprocessing hook.

    NOTE(review): the hooks ``__build`` and ``__preprocess_tensor`` are
    name-mangled, so the abstract methods actually registered on this class are
    ``_ModelWrapper__build`` and ``_ModelWrapper__preprocess_tensor``. A subclass
    that writes ``def __build(...)`` in its own body gets ``_<Subclass>__build``,
    which does NOT override the hook; subclasses must implement the mangled
    names. Renaming the hooks would break existing subclasses, so it is left as is.
    """

    def __init__(self, classes, base_weights_path=None, weights_path=None):
        """Build the model and optionally load trained weights into it.

        Args:
            classes: Class labels of the model, exposed via :attr:`classes`.
            base_weights_path: Passed unchanged to the subclass build hook.
            weights_path: Weights file loaded into the built model with
                ``load_weights``. When ``None`` (the default), no weights are
                loaded and the model keeps whatever the build hook produced.
        """
        self.__classes = classes
        self.model = self.__build(base_weights_path)
        # ``weights_path`` defaults to ``None``, which is not a loadable path;
        # skip loading instead of failing on the default argument.
        if weights_path is not None:
            self.model.load_weights(weights_path)

    @property
    @abc.abstractmethod
    def input_shape(self):
        """Model input shape.

        Its leading entries (everything but the last) are used as the PIL target
        size when resizing; presumably ``(width, height, channels)`` — confirm
        against subclasses, since PIL sizes are ``(width, height)``.
        """
        raise NotImplementedError

    @property
    def classes(self):
        """Class labels passed to the constructor."""
        return self.__classes

    def prep_images(self, images):
        """Turn a batch of images into a preprocessed model input tensor.

        Every image is resized to the model input size and converted to RGB.
        Images for which that fails are replaced by a blank image, so a single
        broken image does not abort the whole batch.

        Args:
            images: Non-empty iterable of images (presumably PIL images — each
                must support ``resize`` and ``convert``).

        Returns:
            ``(tensor, valid_mask)``: the tensor after the subclass
            preprocessing hook, and a tuple of booleans that is ``False`` for
            each image that was replaced by a blank image.

        Raises:
            ValueError: If ``images`` is empty.
        """
        results = [self.__monitored_resize_and_convert(image) for image in images]
        if not results:
            # ``zip(*[])`` would otherwise fail with an opaque unpacking error.
            raise ValueError("prep_images() requires at least one image")
        images, valid_mask = zip(*results)
        tensor = self.__images_to_tensor(images)
        tensor = self.__preprocess_tensor(tensor)
        return tensor, valid_mask

    def __monitored_resize_and_convert(self, image):
        """Resize/convert one image, falling back to a blank image on any error.

        Returns:
            ``(image, valid)`` where ``valid`` is ``False`` if the fallback was used.
        """
        # RED-5170: fails if image is 'broken'
        # Deliberately broad best-effort catch (OSError is already covered by
        # Exception): one bad image must degrade to a placeholder, not crash.
        try:
            return self.__resize_and_convert(image), True
        except Exception as err:
            return self.__handle_resize_exception(err)

    def __resize_and_convert(self, image):
        """Resize ``image`` to the model input size and convert it to RGB."""
        return image.resize(self.input_shape[:-1]).convert("RGB")

    def __handle_resize_exception(self, err):
        """Log ``err`` and return a blank RGB placeholder marked as invalid."""
        # ``warning`` rather than the deprecated ``warn`` alias (assumes
        # get_logger returns a stdlib-compatible logger — confirm).
        logger.warning(f"{err}: Couldn't resize image, replace with blank image and passthrough.")
        image = Image.new("RGB", self.input_shape[:-1])
        return image, False

    @staticmethod
    def __images_to_tensor(images):
        """Stack the images into a single numpy array, one entry per image."""
        return np.array(list(map(tf.keras.preprocessing.image.img_to_array, images)))

    @abc.abstractmethod
    def __preprocess_tensor(self, tensor):
        """Apply model-specific preprocessing to the stacked image array."""
        raise NotImplementedError

    @abc.abstractmethod
    def __build(self, base_weights=None) -> tf.keras.models.Model:
        """Construct and return the Keras model (``base_weights`` from ``__init__``)."""
        raise NotImplementedError