import abc

import numpy as np
import tensorflow as tf


class ModelWrapper(abc.ABC):
    """Abstract base pairing a Keras model with the image preprocessing it needs.

    Subclasses provide:
      * ``input_shape`` -- the shape images are resized to (see its docstring);
      * ``_build`` -- constructs and returns the ``tf.keras.models.Model``;
      * ``_preprocess_tensor`` -- model-specific preprocessing of a batch array.

    This base class builds the model, optionally loads trained weights into it,
    and converts a collection of images into a preprocessed batch via
    ``prep_images``.
    """

    # NOTE: the subclass hooks use a single leading underscore on purpose.
    # Double-underscore names are mangled per class (``__build`` inside this
    # class becomes ``_ModelWrapper__build``), so a subclass defining its own
    # ``__build`` would never override the abstract method and the subclass
    # could never be instantiated.

    def __init__(self, classes, base_weights_path=None, weights_path=None):
        """Build the model and optionally load trained weights into it.

        Args:
            classes: Class labels; stored as-is and exposed via ``classes``.
            base_weights_path: Passed positionally to ``_build`` as its
                ``base_weights`` argument.
            weights_path: File passed to ``model.load_weights``. When ``None``
                (the default), no weights are loaded and the model keeps
                whatever weights ``_build`` gave it.
        """
        self.__classes = classes
        self.model = self._build(base_weights_path)
        # Calling load_weights(None) would fail, which made the documented
        # default for weights_path unusable; only load when a path is given.
        if weights_path is not None:
            self.model.load_weights(weights_path)

    @property
    @abc.abstractmethod
    def input_shape(self):
        """Shape that input images are resized to.

        Only the first two entries are used for resizing. They are assumed to
        be ``(height, width)`` following the Keras channels-last
        ``(height, width, channels)`` convention -- TODO confirm for each
        subclass.
        """
        raise NotImplementedError

    @property
    def classes(self):
        """The class labels given to the constructor, unchanged."""
        return self.__classes

    @abc.abstractmethod
    def _preprocess_tensor(self, tensor):
        """Apply model-specific preprocessing to a batch array.

        Called by ``prep_images`` with the output of ``_images_to_tensor``;
        whatever it returns is what ``prep_images`` returns.
        """
        raise NotImplementedError

    @staticmethod
    def _images_to_tensor(images):
        """Stack images into a single NumPy array.

        Each image is converted with
        ``tf.keras.preprocessing.image.img_to_array``, and the per-image arrays
        are combined with ``np.array``, which adds a leading batch axis.
        """
        return np.array(
            [tf.keras.preprocessing.image.img_to_array(image) for image in images]
        )

    def _resize_and_convert(self, image):
        """Resize ``image`` to ``input_shape`` and convert it to RGB.

        ``image`` is presumably a ``PIL.Image.Image`` (it must provide
        ``resize`` and ``convert``) -- verify against callers.
        """
        # PIL's Image.resize takes (width, height), whereas input_shape is
        # assumed to be (height, width, ...); passing input_shape[:-1] directly
        # swapped the two for non-square shapes.
        height, width = self.input_shape[:2]
        return image.resize((width, height)).convert("RGB")

    def prep_images(self, images):
        """Resize, convert, stack and preprocess ``images`` into one batch.

        Args:
            images: Iterable of image objects accepted by
                ``_resize_and_convert``.

        Returns:
            The result of ``_preprocess_tensor`` applied to the stacked batch.
        """
        resized = [self._resize_and_convert(image) for image in images]
        tensor = self._images_to_tensor(resized)
        return self._preprocess_tensor(tensor)

    @abc.abstractmethod
    def _build(self, base_weights=None) -> tf.keras.models.Model:
        """Construct and return the Keras model.

        Args:
            base_weights: The ``base_weights_path`` given to the constructor
                (``None`` by default); its interpretation is up to the subclass.
        """
        raise NotImplementedError