eliminated redai dependency; updated requirement versions
This commit is contained in:
parent
2b48c6108b
commit
94783c54f2
@ -47,8 +47,9 @@ class MlflowModelReader:
|
||||
|
||||
base_weights_path, weights_path = self.__get_weights_path(run_id)
|
||||
|
||||
model_handle = model_handle_builder(self.__get_classes(run_id), base_weights=base_weights_path)
|
||||
model_handle.load_top_weights(weights_path)
|
||||
model_handle = model_handle_builder(
|
||||
self.__get_classes(run_id), base_weights_path=base_weights_path, weights_path=weights_path
|
||||
)
|
||||
|
||||
return model_handle
|
||||
|
||||
|
||||
@ -5,7 +5,7 @@ class PredictionModelHandle:
|
||||
"""Simplifies usage of ModelHandle instances for prediction purposes."""
|
||||
|
||||
def __init__(self, model_handle):
|
||||
self.__predict = rcompose(model_handle.prep_images, model_handle.model.predict_proba)
|
||||
self.__predict = rcompose(model_handle.prep_images, model_handle.model.predict)
|
||||
|
||||
def predict(self, *args, **kwargs):
|
||||
return self.__predict(*args, **kwargs)
|
||||
|
||||
75
image_prediction/redai_adapter/model_wrapper.py
Normal file
75
image_prediction/redai_adapter/model_wrapper.py
Normal file
@ -0,0 +1,75 @@
|
||||
import os
|
||||
|
||||
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
|
||||
|
||||
import numpy as np
|
||||
|
||||
import tensorflow as tf
|
||||
|
||||
|
||||
class EfficientNetWrapper:
|
||||
|
||||
def __init__(self, classes, base_weights_path=None, weights_path=None):
|
||||
self.__classes = classes
|
||||
self.__input_shape = (224, 224, 3)
|
||||
self.model = self.__build(base_weights_path)
|
||||
self.model.load_weights(weights_path)
|
||||
|
||||
@property
|
||||
def input_shape(self):
|
||||
return self.__input_shape
|
||||
|
||||
@property
|
||||
def classes(self):
|
||||
return self.__classes
|
||||
|
||||
@staticmethod
|
||||
def __preprocess_tensor(tensor):
|
||||
return tf.keras.applications.efficientnet.preprocess_input(tensor)
|
||||
|
||||
@staticmethod
|
||||
def __images_to_tensor(images):
|
||||
return np.array(list(map(tf.keras.preprocessing.image.img_to_array, images)))
|
||||
|
||||
def __resize_and_convert(self, image):
|
||||
return image.resize(self.input_shape[:-1]).convert("RGB")
|
||||
|
||||
def prep_images(self, images):
|
||||
images = map(self.__resize_and_convert, images)
|
||||
tensor = self.__images_to_tensor(images)
|
||||
tensor = self.__preprocess_tensor(tensor)
|
||||
|
||||
return tensor
|
||||
|
||||
def __build(self, base_weights=None) -> tf.keras.models.Model:
|
||||
input_img = tf.keras.layers.Input(shape=self.input_shape)
|
||||
|
||||
pretrained = tf.keras.applications.efficientnet.EfficientNetB0(
|
||||
include_top=False, input_tensor=tf.keras.layers.Input(shape=self.input_shape), weights=base_weights
|
||||
)
|
||||
|
||||
pretrained.trainable = False
|
||||
|
||||
for layer in pretrained.layers:
|
||||
layer.trainable = False
|
||||
|
||||
pretrained = pretrained(input_img)
|
||||
|
||||
finetuned = tf.keras.layers.Flatten()(pretrained)
|
||||
finetuned = tf.keras.layers.Dense(512, activation="relu")(finetuned)
|
||||
finetuned = tf.keras.layers.Dropout(0.2)(finetuned)
|
||||
finetuned = tf.keras.layers.Dense(128, activation="relu")(finetuned)
|
||||
finetuned = tf.keras.layers.Dropout(0.2)(finetuned)
|
||||
finetuned = tf.keras.layers.Dense(32, activation="relu")(finetuned)
|
||||
finetuned = tf.keras.layers.Dropout(0.2)(finetuned)
|
||||
finetuned = tf.keras.layers.Dense(len(self.classes), activation="softmax")(finetuned)
|
||||
|
||||
model = tf.keras.models.Model(inputs=input_img, outputs=finetuned)
|
||||
|
||||
model.compile(
|
||||
loss="categorical_crossentropy",
|
||||
optimizer="adam",
|
||||
metrics=[tf.keras.metrics.Recall(), tf.keras.metrics.Precision()],
|
||||
)
|
||||
|
||||
return model
|
||||
@ -1,5 +1,5 @@
|
||||
[pytest]
|
||||
norecursedirs = incl
|
||||
filterwarnings =
|
||||
ignore:.*imp.*:DeprecationWarning
|
||||
ignore:.*Use.*:DeprecationWarning
|
||||
ignore:.*:DeprecationWarning
|
||||
ignore:.*:DeprecationWarning
|
||||
|
||||
@ -1,26 +1,23 @@
|
||||
Flask==2.0.2
|
||||
Flask==2.1.1
|
||||
requests==2.27.1
|
||||
iteration-utilities==0.11.0
|
||||
dvc==2.9.3
|
||||
dvc==2.10.0
|
||||
dvc[ssh]
|
||||
frozendict==2.3.0
|
||||
waitress==2.0.0
|
||||
envyaml~=1.8.210417
|
||||
waitress==2.1.1
|
||||
envyaml==1.10.211231
|
||||
dependency-check==0.6.*
|
||||
envyaml~=1.8.210417
|
||||
mlflow~=1.20.2
|
||||
numpy~=1.19.3
|
||||
PDFNetPython3~=9.1.0
|
||||
tqdm~=4.62.2
|
||||
pandas~=1.3.1
|
||||
mlflow~=1.20.2
|
||||
tensorflow~=2.5.0
|
||||
PDFNetPython3~=9.1.0
|
||||
Pillow~=8.3.2
|
||||
mlflow==1.24.0
|
||||
numpy==1.22.3
|
||||
tqdm==4.63.1
|
||||
pandas==1.4.1
|
||||
mlflow==1.24.0
|
||||
tensorflow==2.8.0
|
||||
PyYAML~=5.4.1
|
||||
scikit_learn~=0.24.2
|
||||
scikit_learn==1.0.2
|
||||
pytest~=7.1.0
|
||||
funcy==1.17
|
||||
PyMuPDF==1.19.6
|
||||
fpdf==1.7.2
|
||||
coverage==6.3.2
|
||||
Pillow==9.1.0
|
||||
|
||||
@ -63,7 +63,7 @@ def server_ready(url):
|
||||
@pytest.fixture(autouse=True, scope="function")
|
||||
def server_process(server, host_and_port, url):
|
||||
def get_server_process():
|
||||
coverage.process_startup()
|
||||
# coverage.process_startup()
|
||||
return Process(target=run_prediction_server, kwargs={"app": server, **host_and_port})
|
||||
|
||||
server = get_server_process()
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user