# Matthias Bisping ab382646b7 applied black
# 2022-04-03 04:47:49 +02:00
#
# 46 lines
# 1.7 KiB
# Python

import tensorflow as tf
from image_prediction.redai_adapter.model_wrapper import ModelWrapper
class EfficientNetWrapper(ModelWrapper):
    """Classifier built from a frozen EfficientNetB0 backbone plus a small dense head.

    The backbone is instantiated without its classification top and frozen. On top of
    it sits a Flatten -> (Dense + Dropout) x 3 stack ending in a softmax layer with one
    unit per entry in ``classes``.
    """

    def __init__(self, classes, base_weights_path=None, weights_path=None):
        # Assigned before ``super().__init__`` on purpose.
        # NOTE(review): the base class presumably builds the model during construction
        # (via the ``__build`` hook below, which reads ``input_shape``); confirm against
        # ModelWrapper before reordering.
        self.__input_shape = (224, 224, 3)
        super().__init__(classes=classes, base_weights_path=base_weights_path, weights_path=weights_path)

    @property
    def input_shape(self):
        """Shape of a single input image, excluding the batch dimension: ``(224, 224, 3)``."""
        return self.__input_shape

    def _ModelWrapper__preprocess_tensor(self, tensor):
        """Apply the EfficientNet preprocessing function to ``tensor``.

        The explicit ``_ModelWrapper__`` prefix overrides ModelWrapper's name-mangled
        private ``__preprocess_tensor`` hook. Keras documents
        ``efficientnet.preprocess_input`` as a pass-through, because rescaling is built
        into the EfficientNet models. It is called anyway for uniformity with other wrappers.
        """
        return tf.keras.applications.efficientnet.preprocess_input(tensor)

    def _ModelWrapper__build(self, base_weights=None) -> tf.keras.models.Model:
        """Build and compile the classification model.

        Overrides ModelWrapper's name-mangled private ``__build`` hook.

        Args:
            base_weights: Forwarded unchanged as the ``weights`` argument of
                ``EfficientNetB0``. Keras accepts ``None`` (random init),
                ``"imagenet"``, or a path to a weights file.

        Returns:
            A compiled ``tf.keras.models.Model``. It maps a batch of images shaped
            ``self.input_shape`` to a softmax over ``len(self.classes)`` outputs.
        """
        input_img = tf.keras.layers.Input(shape=self.input_shape)

        # Build the backbone for the target shape directly. Previously it was wired to a
        # second, otherwise unused ``Input`` via ``input_tensor``, while the model was in
        # fact applied to ``input_img``.
        backbone = tf.keras.applications.efficientnet.EfficientNetB0(
            include_top=False, input_shape=self.input_shape, weights=base_weights
        )
        # Setting ``trainable`` on a Keras model propagates to every nested layer, so
        # this single assignment freezes the whole backbone. No per-layer loop is needed.
        backbone.trainable = False

        features = backbone(input_img)

        # Trainable head: flatten, then three ReLU Dense layers of shrinking width, each
        # followed by 20% dropout.
        head = tf.keras.layers.Flatten()(features)
        for units in (512, 128, 32):
            head = tf.keras.layers.Dense(units, activation="relu")(head)
            head = tf.keras.layers.Dropout(0.2)(head)
        predictions = tf.keras.layers.Dense(len(self.classes), activation="softmax")(head)

        model = tf.keras.models.Model(inputs=input_img, outputs=predictions)
        # Compiled with Keras defaults (no loss given).
        # NOTE(review): presumably only used for inference here; verify whether callers train it.
        model.compile()
        return model