"""This module translates between the new ModelLoader API and the inconsistent and historically grown redai model and MLflow API as well as the circumstance, that the model artifacts are currently not stored at a single place, due to the need of loading the base weights of the pre-trained model, that became apparent at a later point than the design of the MLflow storage and MlflowModelReader class; that is why the code in this module is so unclean. In the future, a non-adhoc solution should be used that offers a clean API and storage solution. Either implement a well-designed MLflow based solution or look into an alternative such as WandB or use a platform solution such as AWS. """ import os os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3" import importlib import json import os from functools import lru_cache from funcy import rcompose from image_prediction.model_loader.database.connector import DatabaseConnector import mlflow class PredictionModelHandle: """Simplifies usage of ModelHandle instances for prediction purposes.""" def __init__(self, model_handle): self.__model_handle = model_handle def predict(self, *args, **kwargs): predict = rcompose(self.__model_handle.prep_images, self.__model_handle.model.predict) return predict(*args, **kwargs) def predict_proba(self, *args, **kwargs): predict = rcompose(self.__model_handle.prep_images, self.__model_handle.model.predict_proba) return predict(*args, **kwargs) class MlflowModelReader: def __init__(self, mlruns_dir=None): self.mlruns_dir = mlruns_dir mlflow.set_tracking_uri(self.mlruns_dir) @staticmethod def __correct_artifact_uri(run_artifact_uri, base_path): _, suffix = run_artifact_uri.split("mlruns/") return os.path.join(base_path, suffix) def __get_weights_path(self, run_id, prefix="tt"): run = self.__get_run(run_id) artifact_uri = self.__correct_artifact_uri(run.info.to_proto().artifact_uri, self.mlruns_dir) path = os.path.join(artifact_uri, prefix, "train_dev", "estimator") base_path = os.path.join(path, "base_weights.h5") weights_path = os.path.join(path, "weights.h5") return base_path, weights_path @lru_cache(maxsize=None) def __get_run(self, run_id): return mlflow.get_run(run_id) def __get_classes(self, run_id, prefix="tt"): run = self.__get_run(run_id) classes = json.loads(run.data.params[os.path.join(prefix, "train_dev/estimator/classes")].replace("'", '"')) return classes def __get_model_handle(self, run_id): run = self.__get_run(run_id) model_handle_builder = load_object(run.data.params["model_handle_builder"].strip()) base_weights_path, weights_path = self.__get_weights_path(run_id) model_handle = model_handle_builder(self.__get_classes(run_id), base_weights=base_weights_path) model_handle.load_top_weights(weights_path) return model_handle def __get_model(self, run_id) -> PredictionModelHandle: model_handle = self.__get_model_handle(run_id) model = PredictionModelHandle(model_handle) return model def __getitem__(self, run_id): return {"model": self.__get_model(run_id), "classes": self.__get_classes(run_id)} def load_object(object_path): path_fragments = object_path.split(".") module_path = ".".join(path_fragments[:-1]) object_name = path_fragments[-1] module = importlib.import_module(module_path) return getattr(module, object_name) class MlflowConnector(DatabaseConnector): def __init__(self, mlflow_reader: MlflowModelReader): self.mlflow_reader = mlflow_reader def get_object(self, run_id): return self.mlflow_reader[run_id]