111 lines
3.7 KiB
Python
111 lines
3.7 KiB
Python
"""This module translates between the new ModelLoader API and the inconsistent and historically grown redai model and
|
|
MLflow API as well as the circumstance, that the model artifacts are currently not stored at a single place, due to the
|
|
need of loading the base weights of the pre-trained model, that became apparent at a later point than the design of the
|
|
MLflow storage and MlflowModelReader class; that is why the code in this module is so unclean. In the future, a
|
|
non-adhoc solution should be used that offers a clean API and storage solution. Either implement a well-designed MLflow
|
|
based solution or look into an alternative such as WandB or use a platform solution such as AWS.
|
|
"""
|
|
|
|
import os
|
|
|
|
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
|
|
|
|
import ast
import importlib
import json
import os
from functools import lru_cache
|
|
|
|
from funcy import rcompose
|
|
|
|
from image_prediction.model_loader.database.connector import DatabaseConnector
|
|
|
|
import mlflow
|
|
|
|
|
|
class PredictionModelHandle:
    """Prediction facade over a model handle.

    Every prediction call first passes its arguments to the handle's ``prep_images`` and then
    feeds the result to the matching method of ``handle.model``.
    """

    def __init__(self, model_handle):
        self.__model_handle = model_handle

    def __run(self, estimator_method, args, kwargs):
        """Apply ``prep_images`` to the arguments, then ``handle.model.<estimator_method>`` to the result."""
        handle = self.__model_handle
        # Resolve both callables before invoking either one; this mirrors building a
        # left-to-right composition of (prep_images, model method) up front.
        prepare = handle.prep_images
        estimate = getattr(handle.model, estimator_method)
        prepared = prepare(*args, **kwargs)
        return estimate(prepared)

    def predict(self, *args, **kwargs):
        """Preprocess the inputs with ``prep_images`` and return ``model.predict`` of the result."""
        return self.__run("predict", args, kwargs)

    def predict_proba(self, *args, **kwargs):
        """Preprocess the inputs with ``prep_images`` and return ``model.predict_proba`` of the result."""
        return self.__run("predict_proba", args, kwargs)
|
|
|
|
|
|
class MlflowModelReader:
    """Read prediction models and their class lists from an MLflow tracking directory.

    Runs are looked up by run id. Their stored artifact URIs are re-rooted onto ``mlruns_dir``
    (presumably because the stored URI may point at the location where the run was logged
    rather than the local copy).
    """

    def __init__(self, mlruns_dir=None):
        """
        Args:
            mlruns_dir: Local MLflow tracking directory. It is also passed to
                ``mlflow.set_tracking_uri``, which configures MLflow globally.
        """
        self.mlruns_dir = mlruns_dir
        # Per-instance run cache. It replaces ``functools.lru_cache`` on the method, which keyed on
        # ``self`` and kept every reader instance alive for the lifetime of the process.
        self._run_cache = {}
        mlflow.set_tracking_uri(self.mlruns_dir)

    @staticmethod
    def __correct_artifact_uri(run_artifact_uri, base_path):
        """Re-root ``run_artifact_uri`` onto ``base_path``.

        Everything up to and including the last ``"mlruns/"`` in ``run_artifact_uri`` is replaced by
        ``base_path``.

        Raises:
            ValueError: If ``run_artifact_uri`` does not contain ``"mlruns/"``.
        """
        # rpartition instead of split + 2-tuple unpacking: a URI that contains "mlruns/" more than once
        # (e.g. a tracking dir nested below another "mlruns" dir) no longer crashes. The suffix after
        # the last marker is the part below the tracking dir.
        _, marker, suffix = run_artifact_uri.rpartition("mlruns/")
        if not marker:
            raise ValueError(f"Artifact URI {run_artifact_uri!r} does not contain 'mlruns/'.")
        return os.path.join(base_path, suffix)

    def __get_weights_path(self, run_id, prefix="tt"):
        """Return ``(base_weights_path, weights_path)`` inside the run's ``<prefix>/train_dev/estimator`` dir."""
        run = self.__get_run(run_id)
        artifact_uri = self.__correct_artifact_uri(run.info.to_proto().artifact_uri, self.mlruns_dir)
        estimator_dir = os.path.join(artifact_uri, prefix, "train_dev", "estimator")
        return os.path.join(estimator_dir, "base_weights.h5"), os.path.join(estimator_dir, "weights.h5")

    def __get_run(self, run_id):
        """Return the MLflow run for ``run_id``, fetching it at most once per reader."""
        run = self._run_cache.get(run_id)
        if run is None:
            run = mlflow.get_run(run_id)
            self._run_cache[run_id] = run
        return run

    @staticmethod
    def __parse_classes(raw_classes):
        """Parse the string form of a logged class list.

        ``ast.literal_eval`` safely evaluates Python literals such as ``"['cat', \"it's\"]"``. This
        includes names containing apostrophes, which the quote-swap + ``json.loads`` approach corrupted.
        Input that is not a Python literal (e.g. JSON using ``true``/``null``) falls back to that
        original approach.
        """
        try:
            return ast.literal_eval(raw_classes)
        except (ValueError, SyntaxError):
            return json.loads(raw_classes.replace("'", '"'))

    def __get_classes(self, run_id, prefix="tt"):
        """Return the class list stored in the run parameter ``<prefix>/train_dev/estimator/classes``."""
        run = self.__get_run(run_id)
        # NOTE(review): MLflow param keys are plain strings; os.path.join would produce "\\" on Windows.
        # Kept as-is so that the key matches however it was logged — confirm before changing.
        key = os.path.join(prefix, "train_dev/estimator/classes")
        return self.__parse_classes(run.data.params[key])

    def __get_model_handle(self, run_id):
        """Build the model handle for ``run_id`` and load its top weights.

        The builder named by the run's ``model_handle_builder`` parameter is imported, then called
        with the class list and the base-weights path. The resulting handle then loads the top weights.
        """
        run = self.__get_run(run_id)
        # NOTE: load_object imports whatever object path the run parameter names; only read trusted runs.
        model_handle_builder = load_object(run.data.params["model_handle_builder"].strip())
        base_weights_path, weights_path = self.__get_weights_path(run_id)
        model_handle = model_handle_builder(self.__get_classes(run_id), base_weights=base_weights_path)
        model_handle.load_top_weights(weights_path)
        return model_handle

    def __get_model(self, run_id) -> "PredictionModelHandle":
        """Wrap the model handle of ``run_id`` for prediction."""
        return PredictionModelHandle(self.__get_model_handle(run_id))

    def __getitem__(self, run_id):
        """Return ``{"model": <PredictionModelHandle>, "classes": <class list>}`` for ``run_id``."""
        return {"model": self.__get_model(run_id), "classes": self.__get_classes(run_id)}
|
|
|
|
|
|
def load_object(object_path):
    """Import and return the object named by a dotted path such as ``"package.module.Object"``.

    Everything before the last dot is imported as a module. The part after it is looked up as an
    attribute of that module.

    NOTE: this imports whatever module ``object_path`` names; only pass trusted values.

    Raises:
        ValueError: If ``object_path`` contains no module part (e.g. ``"Object"``).
        ImportError: If the module cannot be imported.
        AttributeError: If the module has no attribute with that name.
    """
    module_path, _, object_name = object_path.rpartition(".")
    if not module_path:
        # Previously surfaced as importlib's opaque "Empty module name"; same exception type, clearer message.
        raise ValueError(f"Expected a dotted path like 'package.module.Object', got {object_path!r}.")
    module = importlib.import_module(module_path)
    return getattr(module, object_name)
|
|
|
|
|
|
class MlflowConnector(DatabaseConnector):
    """DatabaseConnector that serves objects from an MlflowModelReader.

    ``get_object(run_id)`` returns whatever the reader yields for ``reader[run_id]``.
    """

    def __init__(self, mlflow_reader: MlflowModelReader):
        self.mlflow_reader = mlflow_reader

    def get_object(self, run_id):
        """Look up ``run_id`` in the wrapped reader."""
        reader = self.mlflow_reader
        found = reader[run_id]
        return found
|