import os

# Set before TensorFlow is imported so the log-level setting is already in
# place when TensorFlow's native logging initializes.
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"

import numpy as np
import tensorflow as tf


class EfficientNetWrapper:
    """Frozen EfficientNetB0 backbone with a small dense classification head.

    The constructor builds and compiles the Keras model and, when a
    checkpoint path is supplied, loads the weights from it. The compiled
    model is exposed as the ``model`` attribute.
    """

    def __init__(self, classes, base_weights_path=None, weights_path=None):
        """Build the model and optionally restore trained weights.

        Args:
            classes: Sized collection of class labels; its length sets the
                width of the softmax output layer.
            base_weights_path: Forwarded as ``weights`` to ``EfficientNetB0``
                (``None`` leaves the backbone randomly initialized).
            weights_path: Checkpoint to load into the full model. When
                ``None`` no weights are loaded and the head stays untrained.
        """
        self.__classes = classes
        self.__input_shape = (224, 224, 3)
        self.model = self.__build(base_weights_path)
        # ``weights_path`` defaults to None; calling load_weights(None) would
        # raise, so only restore a checkpoint when one was actually given.
        if weights_path is not None:
            self.model.load_weights(weights_path)

    @property
    def input_shape(self):
        """Input shape the model is built with: ``(height, width, channels)``."""
        return self.__input_shape

    @property
    def classes(self):
        """The class labels passed to the constructor."""
        return self.__classes

    @staticmethod
    def __preprocess_tensor(tensor):
        """Run ``tensor`` through the EfficientNet ``preprocess_input`` hook."""
        return tf.keras.applications.efficientnet.preprocess_input(tensor)

    @staticmethod
    def __images_to_tensor(images):
        """Convert each image with ``img_to_array`` and stack into one array."""
        return np.array(
            [tf.keras.preprocessing.image.img_to_array(image) for image in images]
        )

    def __resize_and_convert(self, image):
        """Resize ``image`` to the model's spatial size and convert to RGB.

        ``image`` is assumed to be a PIL-style image (it must provide
        ``resize`` and ``convert``) -- verify against callers.
        """
        height, width = self.input_shape[:2]
        # PIL-style ``resize`` takes (width, height), whereas ``input_shape``
        # is (height, width, channels); pass the size in the order expected.
        return image.resize((width, height)).convert("RGB")

    def prep_images(self, images):
        """Turn an iterable of images into a preprocessed batch for the model.

        Each image is resized to the model input size and converted to RGB,
        the results are stacked into a single array, and that array is passed
        through the EfficientNet ``preprocess_input`` function.

        Args:
            images: Iterable of PIL-style images.

        Returns:
            The preprocessed batch, one entry per input image.
        """
        resized = (self.__resize_and_convert(image) for image in images)
        tensor = self.__images_to_tensor(resized)
        return self.__preprocess_tensor(tensor)

    def __build(self, base_weights=None) -> tf.keras.models.Model:
        """Assemble and compile the classifier.

        Args:
            base_weights: Value for the ``weights`` argument of
                ``EfficientNetB0``.

        Returns:
            The compiled Keras model: frozen EfficientNetB0 features followed
            by Dense(512) -> Dense(128) -> Dense(32) (each ReLU with 0.2
            dropout) and a softmax layer with one unit per class.
        """
        layers = tf.keras.layers

        input_img = layers.Input(shape=self.input_shape)
        pretrained = tf.keras.applications.efficientnet.EfficientNetB0(
            include_top=False,
            input_tensor=layers.Input(shape=self.input_shape),
            weights=base_weights,
        )

        # Freeze the backbone so that only the dense head is trainable.
        pretrained.trainable = False
        for layer in pretrained.layers:
            layer.trainable = False

        features = pretrained(input_img)

        finetuned = layers.Flatten()(features)
        for units in (512, 128, 32):
            finetuned = layers.Dense(units, activation="relu")(finetuned)
            finetuned = layers.Dropout(0.2)(finetuned)
        finetuned = layers.Dense(len(self.classes), activation="softmax")(finetuned)

        model = tf.keras.models.Model(inputs=input_img, outputs=finetuned)
        model.compile(
            loss="categorical_crossentropy",
            optimizer="adam",
            metrics=[tf.keras.metrics.Recall(), tf.keras.metrics.Precision()],
        )
        return model