refactoring: model wrapper to base class and derived class for efficient net
This commit is contained in:
parent
070749880e
commit
c80549d5d3
50
image_prediction/redai_adapter/efficient_net_wrapper.py
Normal file
50
image_prediction/redai_adapter/efficient_net_wrapper.py
Normal file
@ -0,0 +1,50 @@
|
||||
import os
|
||||
|
||||
from image_prediction.redai_adapter.model_wrapper import ModelWrapper
|
||||
|
||||
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
|
||||
|
||||
import tensorflow as tf
|
||||
|
||||
|
||||
class EfficientNetWrapper(ModelWrapper):
|
||||
|
||||
def __init__(self, classes, base_weights_path=None, weights_path=None):
|
||||
self.__input_shape = (224, 224, 3)
|
||||
super().__init__(classes=classes, base_weights_path=base_weights_path, weights_path=weights_path)
|
||||
|
||||
@property
|
||||
def input_shape(self):
|
||||
return self.__input_shape
|
||||
|
||||
def _ModelWrapper__preprocess_tensor(self, tensor):
|
||||
return tf.keras.applications.efficientnet.preprocess_input(tensor)
|
||||
|
||||
def _ModelWrapper__build(self, base_weights=None) -> tf.keras.models.Model:
|
||||
input_img = tf.keras.layers.Input(shape=self.input_shape)
|
||||
|
||||
pretrained = tf.keras.applications.efficientnet.EfficientNetB0(
|
||||
include_top=False, input_tensor=tf.keras.layers.Input(shape=self.input_shape), weights=base_weights
|
||||
)
|
||||
|
||||
pretrained.trainable = False
|
||||
|
||||
for layer in pretrained.layers:
|
||||
layer.trainable = False
|
||||
|
||||
pretrained = pretrained(input_img)
|
||||
|
||||
finetuned = tf.keras.layers.Flatten()(pretrained)
|
||||
finetuned = tf.keras.layers.Dense(512, activation="relu")(finetuned)
|
||||
finetuned = tf.keras.layers.Dropout(0.2)(finetuned)
|
||||
finetuned = tf.keras.layers.Dense(128, activation="relu")(finetuned)
|
||||
finetuned = tf.keras.layers.Dropout(0.2)(finetuned)
|
||||
finetuned = tf.keras.layers.Dense(32, activation="relu")(finetuned)
|
||||
finetuned = tf.keras.layers.Dropout(0.2)(finetuned)
|
||||
finetuned = tf.keras.layers.Dense(len(self.classes), activation="softmax")(finetuned)
|
||||
|
||||
model = tf.keras.models.Model(inputs=input_img, outputs=finetuned)
|
||||
|
||||
model.compile()
|
||||
|
||||
return model
|
||||
@ -31,6 +31,7 @@ class MlflowModelReader:
|
||||
|
||||
@lru_cache(maxsize=None)
|
||||
def __get_run(self, run_id):
|
||||
|
||||
return mlflow.get_run(run_id)
|
||||
|
||||
def __get_classes(self, run_id, prefix="tt"):
|
||||
|
||||
@ -1,3 +1,4 @@
|
||||
import abc
|
||||
import os
|
||||
|
||||
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
|
||||
@ -7,25 +8,25 @@ import numpy as np
|
||||
import tensorflow as tf
|
||||
|
||||
|
||||
class EfficientNetWrapper:
|
||||
class ModelWrapper(abc.ABC):
|
||||
|
||||
def __init__(self, classes, base_weights_path=None, weights_path=None):
|
||||
self.__classes = classes
|
||||
self.__input_shape = (224, 224, 3)
|
||||
self.model = self.__build(base_weights_path)
|
||||
self.model.load_weights(weights_path)
|
||||
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
def input_shape(self):
|
||||
return self.__input_shape
|
||||
raise NotImplementedError
|
||||
|
||||
@property
|
||||
def classes(self):
|
||||
return self.__classes
|
||||
|
||||
@staticmethod
|
||||
def __preprocess_tensor(tensor):
|
||||
return tf.keras.applications.efficientnet.preprocess_input(tensor)
|
||||
@abc.abstractmethod
|
||||
def __preprocess_tensor(self, tensor):
|
||||
raise NotImplementedError
|
||||
|
||||
@staticmethod
|
||||
def __images_to_tensor(images):
|
||||
@ -41,31 +42,6 @@ class EfficientNetWrapper:
|
||||
|
||||
return tensor
|
||||
|
||||
@abc.abstractmethod
|
||||
def __build(self, base_weights=None) -> tf.keras.models.Model:
|
||||
input_img = tf.keras.layers.Input(shape=self.input_shape)
|
||||
|
||||
pretrained = tf.keras.applications.efficientnet.EfficientNetB0(
|
||||
include_top=False, input_tensor=tf.keras.layers.Input(shape=self.input_shape), weights=base_weights
|
||||
)
|
||||
|
||||
pretrained.trainable = False
|
||||
|
||||
for layer in pretrained.layers:
|
||||
layer.trainable = False
|
||||
|
||||
pretrained = pretrained(input_img)
|
||||
|
||||
finetuned = tf.keras.layers.Flatten()(pretrained)
|
||||
finetuned = tf.keras.layers.Dense(512, activation="relu")(finetuned)
|
||||
finetuned = tf.keras.layers.Dropout(0.2)(finetuned)
|
||||
finetuned = tf.keras.layers.Dense(128, activation="relu")(finetuned)
|
||||
finetuned = tf.keras.layers.Dropout(0.2)(finetuned)
|
||||
finetuned = tf.keras.layers.Dense(32, activation="relu")(finetuned)
|
||||
finetuned = tf.keras.layers.Dropout(0.2)(finetuned)
|
||||
finetuned = tf.keras.layers.Dense(len(self.classes), activation="softmax")(finetuned)
|
||||
|
||||
model = tf.keras.models.Model(inputs=input_img, outputs=finetuned)
|
||||
|
||||
model.compile()
|
||||
|
||||
return model
|
||||
raise NotImplementedError
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user