adjust caller hierarchy

This commit is contained in:
Julius Unverfehrt 2022-08-30 13:17:21 +02:00
parent 265c61df1a
commit 55a5dd11d6

View File

@ -24,16 +24,12 @@ class ModelWrapper(abc.ABC):
def classes(self):
return self.__classes
@abc.abstractmethod
def __preprocess_tensor(self, tensor):
raise NotImplementedError
def prep_images(self, images):
images, valid_mask = zip(*map(self.__monitored_resize_and_convert, images))
tensor = self.__images_to_tensor(images)
tensor = self.__preprocess_tensor(tensor)
@staticmethod
def __images_to_tensor(images):
return np.array(list(map(tf.keras.preprocessing.image.img_to_array, images)))
def __resize_and_convert(self, image):
return image.resize(self.input_shape[:-1]).convert("RGB")
return tensor, valid_mask
def __monitored_resize_and_convert(self, image):
# RED-5170: fails if image is 'broken'
@ -46,18 +42,22 @@ class ModelWrapper(abc.ABC):
return image, valid
def __resize_and_convert(self, image):
return image.resize(self.input_shape[:-1]).convert("RGB")
def __handle_resize_exception(self, err):
logger.warn(f"{err}: couldn't resize image, replace and passthrough.")
image = Image.new("RGB", self.input_shape[:-1])
valid = False
return image, valid
def prep_images(self, images):
images, valid_mask = zip(*map(self.__monitored_resize_and_convert, images))
tensor = self.__images_to_tensor(images)
tensor = self.__preprocess_tensor(tensor)
@staticmethod
def __images_to_tensor(images):
return np.array(list(map(tf.keras.preprocessing.image.img_to_array, images)))
return tensor, valid_mask
@abc.abstractmethod
def __preprocess_tensor(self, tensor):
raise NotImplementedError
@abc.abstractmethod
def __build(self, base_weights=None) -> tf.keras.models.Model: