"""Build prediction models and their class lists from runs in a local MLflow ``mlruns`` store."""

import ast
import importlib
import json
import os
import posixpath
from functools import lru_cache

import mlflow

from image_prediction.redai_adapter.model import PredictionModelHandle


class MlflowModelReader:
    """Dict-like reader: ``reader[run_id]`` returns the run's model and classes.

    The returned value is ``{"model": PredictionModelHandle, "classes": <parsed classes param>}``.
    """

    def __init__(self, mlruns_dir=None):
        """
        Args:
            mlruns_dir: Passed to ``mlflow.set_tracking_uri`` and also used as the local
                root onto which run artifact URIs are re-based.
                NOTE(review): with the default ``None`` the tracking URI is left to MLflow's
                default, but re-basing artifact URIs (``os.path.join(None, ...)``) will
                fail. Confirm whether ``None`` is actually supported by callers.
        """
        self.mlruns_dir = mlruns_dir
        # Side effect: this sets the process-wide MLflow tracking URI.
        mlflow.set_tracking_uri(self.mlruns_dir)
        # Per-instance run cache. A method-level @lru_cache would key on ``self`` and
        # keep every reader instance alive for the lifetime of the process.
        self._get_run_cached = lru_cache(maxsize=None)(mlflow.get_run)

    @staticmethod
    def __correct_artifact_uri(run_artifact_uri, base_path):
        """Replace everything up to and including the last ``mlruns/`` with ``base_path``.

        Raises:
            ValueError: if ``run_artifact_uri`` contains no ``mlruns/`` segment.
        """
        # rpartition keeps working when the base path itself contains "mlruns/";
        # a plain split() into two names would fail to unpack in that case.
        _, sep, suffix = run_artifact_uri.rpartition("mlruns/")
        if not sep:
            raise ValueError(f"Artifact URI {run_artifact_uri!r} does not contain 'mlruns/'")
        return os.path.join(base_path, suffix)

    def __get_weights_path(self, run_id, prefix="tt"):
        """Return ``(base_weights_path, weights_path)`` for the run's estimator.

        Both files live under ``<artifacts>/<prefix>/train_dev/estimator/``, with the
        artifact root re-based onto ``self.mlruns_dir``.
        """
        run = self.__get_run(run_id)
        artifact_uri = self.__correct_artifact_uri(run.info.to_proto().artifact_uri, self.mlruns_dir)
        path = os.path.join(artifact_uri, prefix, "train_dev", "estimator")
        base_path = os.path.join(path, "base_weights.h5")
        weights_path = os.path.join(path, "weights.h5")
        return base_path, weights_path

    def __get_run(self, run_id):
        """Fetch the MLflow run for ``run_id``, memoized per reader instance."""
        return self._get_run_cached(run_id)

    def __get_classes(self, run_id, prefix="tt"):
        """Parse and return the ``<prefix>/train_dev/estimator/classes`` run parameter.

        Raises:
            KeyError: if the run has no such parameter.
            ValueError / SyntaxError: if the parameter is neither JSON nor a Python literal.
        """
        run = self.__get_run(run_id)
        # MLflow param keys always use "/"; os.path.join would produce "\" on Windows.
        raw = run.data.params[posixpath.join(prefix, "train_dev/estimator/classes")]
        try:
            return json.loads(raw)
        except ValueError:
            # Params logged via str(list) are Python reprs (single-quoted strings).
            # Parse them as literals rather than rewriting quotes, which breaks on
            # class names that contain an apostrophe or a double quote.
            return ast.literal_eval(raw.strip())

    def __get_model_handle(self, run_id):
        """Instantiate the model handle via the builder named in the run's params."""
        run = self.__get_run(run_id)
        # SECURITY NOTE: the "model_handle_builder" run parameter selects an arbitrary
        # importable object. Only use this reader on a trusted tracking store.
        model_handle_builder = load_object(run.data.params["model_handle_builder"].strip())
        base_weights_path, weights_path = self.__get_weights_path(run_id)
        model_handle = model_handle_builder(
            self.__get_classes(run_id), base_weights_path=base_weights_path, weights_path=weights_path
        )
        return model_handle

    def __get_model(self, run_id) -> PredictionModelHandle:
        """Wrap the run's model handle in a ``PredictionModelHandle``."""
        model_handle = self.__get_model_handle(run_id)
        model = PredictionModelHandle(model_handle)
        return model

    def __getitem__(self, run_id):
        """Return ``{"model": ..., "classes": ...}`` for ``run_id``."""
        return {"model": self.__get_model(run_id), "classes": self.__get_classes(run_id)}


def load_object(object_path):
    """Import and return the object named by a dotted path such as ``"pkg.module.attr"``.

    Raises:
        ValueError: if ``object_path`` has no module part (no ``.`` before the name).
        ImportError: if the module cannot be imported.
        AttributeError: if the module has no such attribute.
    """
    module_path, _, object_name = object_path.rpartition(".")
    if not module_path:
        raise ValueError(f"Expected a dotted path 'module.object', got {object_path!r}")
    module = importlib.import_module(module_path)
    return getattr(module, object_name)