From 55a5dd11d65a29c7838e62bbb04eaf4b8ecd7840 Mon Sep 17 00:00:00 2001 From: Julius Unverfehrt Date: Tue, 30 Aug 2022 13:17:21 +0200 Subject: [PATCH] adjust caller hierarchy --- .../redai_adapter/model_wrapper.py | 28 +++++++++---------- 1 file changed, 14 insertions(+), 14 deletions(-) diff --git a/image_prediction/redai_adapter/model_wrapper.py b/image_prediction/redai_adapter/model_wrapper.py index de9d543..5ad93a9 100644 --- a/image_prediction/redai_adapter/model_wrapper.py +++ b/image_prediction/redai_adapter/model_wrapper.py @@ -24,16 +24,12 @@ class ModelWrapper(abc.ABC): def classes(self): return self.__classes - @abc.abstractmethod - def __preprocess_tensor(self, tensor): - raise NotImplementedError + def prep_images(self, images): + images, valid_mask = zip(*map(self.__monitored_resize_and_convert, images)) + tensor = self.__images_to_tensor(images) + tensor = self.__preprocess_tensor(tensor) - @staticmethod - def __images_to_tensor(images): - return np.array(list(map(tf.keras.preprocessing.image.img_to_array, images))) - - def __resize_and_convert(self, image): - return image.resize(self.input_shape[:-1]).convert("RGB") + return tensor, valid_mask def __monitored_resize_and_convert(self, image): # RED-5170: fails if image is 'broken' @@ -46,18 +42,22 @@ class ModelWrapper(abc.ABC): return image, valid + def __resize_and_convert(self, image): + return image.resize(self.input_shape[:-1]).convert("RGB") + def __handle_resize_exception(self, err): logger.warn(f"{err}: couldn't resize image, replace and passthrough.") image = Image.new("RGB", self.input_shape[:-1]) valid = False return image, valid - def prep_images(self, images): - images, valid_mask = zip(*map(self.__monitored_resize_and_convert, images)) - tensor = self.__images_to_tensor(images) - tensor = self.__preprocess_tensor(tensor) + @staticmethod + def __images_to_tensor(images): + return np.array(list(map(tf.keras.preprocessing.image.img_to_array, images))) - return tensor, valid_mask + @abc.abstractmethod + def __preprocess_tensor(self, tensor): + raise NotImplementedError @abc.abstractmethod def __build(self, base_weights=None) -> tf.keras.models.Model: