"""Load model handles and class labels from MLflow tracking runs."""

import ast
import importlib
import json
import os
import posixpath
import warnings

import numpy as np

from image_prediction.locations import BASE_WEIGHTS
from image_prediction.model_loader.loader import ModelLoader

# The filter must be installed before mlflow is imported (presumably the
# pkg_resources DeprecationWarnings are emitted at mlflow import time), which
# is why the mlflow import sits below it.
warnings.filterwarnings("ignore", category=DeprecationWarning, module="pkg_resources")

import mlflow  # noqa: E402  (must follow the warnings filter above)


def load_object(object_path):
    """Import and return an object given its fully qualified dotted path.

    Args:
        object_path: Dotted path such as ``"package.module.name"``.

    Returns:
        The attribute ``name`` of the imported module ``package.module``.

    Raises:
        ValueError: If ``object_path`` contains no module component.
        ImportError: If the module cannot be imported.
        AttributeError: If the module has no attribute with that name.
    """
    module_path, _, object_name = object_path.rpartition(".")
    if not module_path:
        raise ValueError("Expected a dotted path 'module.object', got {!r}".format(object_path))
    module = importlib.import_module(module_path)
    return getattr(module, object_name)


def to_local_path(uri):
    """Convert a ``file://`` URI to a local filesystem path.

    Args:
        uri: A URI such as ``"file:///data/mlruns/0/<run>/artifacts"``.

    Returns:
        ``uri`` without its ``file://`` scheme prefix; any other input is
        returned unchanged rather than having arbitrary characters cut off.
    """
    prefix = "file://"
    if uri.startswith(prefix):
        return uri[len(prefix):]
    return uri


class MlflowModelReader:
    """Read weights, class labels and the model-handle builder of one MLflow run."""

    def __init__(self, run_id, mlruns_dir=None):
        """Fetch run metadata and resolve where its artifacts live locally.

        Args:
            run_id: ID of the MLflow run to read.
            mlruns_dir: Passed to ``mlflow.set_tracking_uri`` and used as the
                directory the run's artifact URI is re-based onto. When
                ``None``, the artifact URI recorded in the run is used as-is
                (converted to a local path if it is a ``file://`` URI).
        """
        mlflow.set_tracking_uri(mlruns_dir)
        self.run_id = run_id
        self.run = mlflow.get_run(run_id)
        self.artifact_uri = self.__correct_artifact_uri(self.run.info.to_proto().artifact_uri, mlruns_dir)

    @staticmethod
    def __correct_artifact_uri(run_artifact_uri, base_path):
        """Re-base a run's recorded artifact URI onto ``base_path``.

        Everything after the last ``mlruns/`` segment of ``run_artifact_uri``
        is appended to ``base_path``. Presumably this lets runs be read after
        the mlruns directory has been moved from where it was logged.

        Raises:
            ValueError: If ``base_path`` is given but the URI has no
                ``mlruns/`` segment to re-base from.
        """
        if base_path is None:
            # Nothing to re-base onto: trust the location MLflow recorded.
            return to_local_path(run_artifact_uri)
        _, separator, suffix = run_artifact_uri.rpartition("mlruns/")
        if not separator:
            raise ValueError("Artifact URI {!r} contains no 'mlruns/' segment".format(run_artifact_uri))
        return os.path.join(base_path, suffix)

    def get_weights_path(self, prefix="tt"):
        """Return the path of ``weights.h5`` under ``<artifact_uri>/<prefix>/train_dev/estimator``."""
        path = os.path.join(self.artifact_uri, prefix, "train_dev", "estimator", "weights.h5")
        return path

    def get_classes(self, prefix="tt"):
        """Return the class labels logged as the ``<prefix>/train_dev/estimator/classes`` param.

        The param is parsed as a Python literal (the repr of a list is what
        ``str(list)`` produces). If that fails, it falls back to the legacy
        single-to-double-quote + JSON parse.
        """
        # Param keys always use '/', independent of the OS path separator.
        param_key = posixpath.join(prefix, "train_dev/estimator/classes")
        raw_classes = self.run.data.params[param_key]
        try:
            classes = ast.literal_eval(raw_classes)
        except (ValueError, SyntaxError):
            classes = json.loads(raw_classes.replace("'", '"'))
        return classes

    def get_model_handle(self, base_weights=None):
        """Build the run's model handle and load its trained top weights.

        The builder is the object named by the run's ``model_handle_builder``
        param. It is called with the run's classes and ``base_weights``.
        """
        weights_path = self.get_weights_path()
        model_handle_builder = load_object(self.run.data.params["model_handle_builder"].strip())
        model_handle = model_handle_builder(self.get_classes(), base_weights=base_weights)
        model_handle.load_top_weights(weights_path)
        return model_handle


class MlflowLoader(ModelLoader):
    """ModelLoader that builds model handles from MLflow runs.

    The most recently loaded handle is cached and reused only for the same run ID.
    """

    def __init__(self, mlruns_dir):
        self.__mlruns_dir = mlruns_dir
        self.__model_handle = None
        # Run ID the cached handle was built from; None until the first load.
        self.__run_id = None

    def load_model(self, run_id):
        """Return the model handle for ``run_id``, building it on a cache miss.

        A request for a different run replaces the cached handle instead of
        silently returning the previously loaded run's model.
        """
        if self.__model_handle is None or self.__run_id != run_id:
            mlflow_reader = MlflowModelReader(run_id, mlruns_dir=self.__mlruns_dir)
            self.__model_handle = mlflow_reader.get_model_handle(BASE_WEIGHTS)
            self.__run_id = run_id
        return self.__model_handle

    def load_classes(self, run_id):
        """Return the readable class labels in the order of ``model.classes_``.

        Assumes ``model_handle.model.classes_`` holds integer indices into
        ``model_handle.classes`` (TODO confirm against the model handle).
        """
        model_handle = self.load_model(run_id)
        class_indices = np.asarray(model_handle.model.classes_)
        classes_readable = np.array(model_handle.classes)
        return classes_readable[class_indices]