adjust caller hierarchy

This commit is contained in:
Julius Unverfehrt 2022-08-30 13:17:21 +02:00
parent 265c61df1a
commit 55a5dd11d6

View File

@ -24,16 +24,12 @@ class ModelWrapper(abc.ABC):
def classes(self): def classes(self):
return self.__classes return self.__classes
@abc.abstractmethod def prep_images(self, images):
def __preprocess_tensor(self, tensor): images, valid_mask = zip(*map(self.__monitored_resize_and_convert, images))
raise NotImplementedError tensor = self.__images_to_tensor(images)
tensor = self.__preprocess_tensor(tensor)
@staticmethod return tensor, valid_mask
def __images_to_tensor(images):
return np.array(list(map(tf.keras.preprocessing.image.img_to_array, images)))
def __resize_and_convert(self, image):
return image.resize(self.input_shape[:-1]).convert("RGB")
def __monitored_resize_and_convert(self, image): def __monitored_resize_and_convert(self, image):
# RED-5170: fails if image is 'broken' # RED-5170: fails if image is 'broken'
@ -46,18 +42,22 @@ class ModelWrapper(abc.ABC):
return image, valid return image, valid
def __resize_and_convert(self, image):
return image.resize(self.input_shape[:-1]).convert("RGB")
def __handle_resize_exception(self, err): def __handle_resize_exception(self, err):
logger.warn(f"{err}: couldn't resize image, replace and passthrough.") logger.warn(f"{err}: couldn't resize image, replace and passthrough.")
image = Image.new("RGB", self.input_shape[:-1]) image = Image.new("RGB", self.input_shape[:-1])
valid = False valid = False
return image, valid return image, valid
def prep_images(self, images): @staticmethod
images, valid_mask = zip(*map(self.__monitored_resize_and_convert, images)) def __images_to_tensor(images):
tensor = self.__images_to_tensor(images) return np.array(list(map(tf.keras.preprocessing.image.img_to_array, images)))
tensor = self.__preprocess_tensor(tensor)
return tensor, valid_mask @abc.abstractmethod
def __preprocess_tensor(self, tensor):
raise NotImplementedError
@abc.abstractmethod @abc.abstractmethod
def __build(self, base_weights=None) -> tf.keras.models.Model: def __build(self, base_weights=None) -> tf.keras.models.Model: