63 lines
1.8 KiB
Python
63 lines
1.8 KiB
Python
import abc
|
|
|
|
import numpy as np
|
|
import tensorflow as tf
|
|
from PIL import Image
|
|
|
|
from image_prediction.utils import get_logger
|
|
|
|
# Module-level logger obtained from the project's ``get_logger`` helper;
# ModelWrapper uses it to report images that could not be resized.
logger = get_logger()
|
|
|
|
|
|
class ModelWrapper(abc.ABC):
    """Abstract base for Keras image models: builds the model, loads weights, prepares image batches.

    Subclasses provide the input shape, the model architecture and the
    tensor preprocessing step.

    NOTE: ``__build`` and ``__preprocess_tensor`` are name-mangled by Python to
    ``_ModelWrapper__build`` and ``_ModelWrapper__preprocess_tensor``. A subclass
    must define methods under exactly those mangled names to implement them; a
    subclass method literally named ``__build`` becomes ``_<Subclass>__build``,
    leaving the abstract method unimplemented, and instantiation then raises
    ``TypeError``.
    """

    def __init__(self, classes, base_weights_path=None, weights_path=None):
        """Build the model and, if a path is given, load trained weights.

        Args:
            classes: Class labels; stored unchanged and exposed via ``classes``.
            base_weights_path: Forwarded to the subclass build hook.
            weights_path: Passed to ``model.load_weights``. When ``None`` (the
                default), no weights are loaded.
        """
        self.__classes = classes
        self.model = self.__build(base_weights_path)
        # Keras' load_weights expects a file path, so skip it for the None default.
        if weights_path is not None:
            self.model.load_weights(weights_path)

    @property
    @abc.abstractmethod
    def input_shape(self):
        """Model input shape.

        All but its last element are used as the PIL ``(width, height)`` size
        when resizing and when creating blank replacement images.
        """
        raise NotImplementedError

    @property
    def classes(self):
        """Class labels given to the constructor."""
        return self.__classes

    def prep_images(self, images):
        """Resize, convert to RGB, stack and preprocess a batch of PIL images.

        Images that fail to resize are replaced by a blank RGB image and are
        flagged as invalid in the returned mask.

        Args:
            images: Iterable of PIL images.

        Returns:
            ``(tensor, valid_mask)``. ``tensor`` is the result of
            ``__preprocess_tensor`` on the stacked image arrays. ``valid_mask``
            is a tuple of booleans, one per image, that is ``False`` where a
            blank image was substituted. For an empty input, ``tensor`` has
            zero length and ``valid_mask`` is ``()``.
        """
        results = [self.__monitored_resize_and_convert(image) for image in images]
        if results:
            prepped, valid_mask = zip(*results)
            tensor = self.__images_to_tensor(prepped)
        else:
            # zip(*[]) yields nothing, so unpacking it would fail. Build a
            # zero-length batch whose per-image shape and dtype match a real one
            # by converting one blank image and slicing it away.
            blank = Image.new("RGB", self.input_shape[:-1])
            tensor = self.__images_to_tensor([blank])[:0]
            valid_mask = ()
        tensor = self.__preprocess_tensor(tensor)
        return tensor, valid_mask

    def __monitored_resize_and_convert(self, image):
        """Resize and convert ``image``, falling back to a blank image on failure.

        Returns:
            ``(image, valid)``. ``valid`` is ``False`` when the fallback was used.
        """
        # RED-5170: resizing fails if the image is 'broken'. Catch broadly on
        # purpose so that one bad image does not abort the whole batch
        # (OSError is already covered by Exception).
        try:
            return self.__resize_and_convert(image), True
        except Exception as err:
            return self.__handle_resize_exception(err)

    def __resize_and_convert(self, image):
        """Resize ``image`` to the model input size, then convert it to RGB."""
        return image.resize(self.input_shape[:-1]).convert("RGB")

    def __handle_resize_exception(self, err):
        """Log ``err`` and return ``(blank RGB image of input size, False)``."""
        logger.warning(f"{err}: Couldn't resize image, replace with blank image and passthrough.")
        image = Image.new("RGB", self.input_shape[:-1])
        valid = False
        return image, valid

    @staticmethod
    def __images_to_tensor(images):
        """Convert each PIL image with Keras ``img_to_array`` and stack them into one numpy array."""
        return np.array(list(map(tf.keras.preprocessing.image.img_to_array, images)))

    @abc.abstractmethod
    def __preprocess_tensor(self, tensor):
        """Model-specific preprocessing of the stacked image array.

        Implement as ``_ModelWrapper__preprocess_tensor`` in subclasses (name mangling).
        """
        raise NotImplementedError

    @abc.abstractmethod
    def __build(self, base_weights=None) -> tf.keras.models.Model:
        """Build and return the Keras model, optionally using ``base_weights``.

        Implement as ``_ModelWrapper__build`` in subclasses (name mangling).
        """
        raise NotImplementedError
|