# NOTE: Every statement in this module is commented out, so importing it defines no names and has no side effects.
# The disabled code is kept below, one statement per line, for reference.
# NOTE(review): the reason the module was disabled is not recorded here -- confirm whether it can be removed.
#
# """This module translates between the new ModelLoader API and the inconsistent and historically grown redai model and
# MLflow API as well as the circumstance, that the model artifacts are currently not stored at a single place, due to the
# need of loading the base weights of the pre-trained model, that became apparent at a later point than the design of the
# MLflow storage and MlflowModelReader class; that is why the code in this module is so unclean. In the future, a
# non-adhoc solution should be used that offers a clean API and storage solution. Either implement a well-designed MLflow
# based solution or look into an alternative such as WandB or use a platform solution such as AWS.
# """
# import importlib
# import json
# import os
# import warnings
# from typing import Mapping
#
# import numpy as np
# from funcy import rcompose
#
# from image_prediction.exceptions import IncorrectInstantiation
# from image_prediction.model_loader.loader import ModelLoader
#
# warnings.filterwarnings("ignore", category=DeprecationWarning, module="pkg_resources")
#
# import mlflow
#
#
# def load_object(object_path):
#     path_fragments = object_path.split(".")
#
#     module_path = ".".join(path_fragments[:-1])
#     object_name = path_fragments[-1]
#
#     module = importlib.import_module(module_path)
#     return getattr(module, object_name)
#
#
# def to_local_path(uri):
#     return uri[7:]
#
#
# class MlflowModelReader:
#
#     def __init__(self, run_id, mlruns_dir=None):
#         mlflow.set_tracking_uri(mlruns_dir)
#
#         self.run_id = run_id
#         self.run = mlflow.get_run(run_id)
#         self.artifact_uri = self.__correct_artifact_uri(self.run.info.to_proto().artifact_uri, mlruns_dir)
#
#     @staticmethod
#     def __correct_artifact_uri(run_artifact_uri, base_path):
#         _, suffix = run_artifact_uri.split("mlruns/")
#         return os.path.join(base_path, suffix)
#
#     def get_weights_path(self, prefix="tt"):
#         path = os.path.join(self.artifact_uri, prefix, "train_dev", "estimator", "weights.h5")
#         return path
#
#     def get_classes(self, prefix="tt"):
#         classes = json.loads(
#             self.run.data.params[os.path.join(prefix, "train_dev/estimator/classes")].replace("'", '"')
#         )
#         return classes
#
#     def get_model_handle(self, base_weights=None):
#         weights_path = self.get_weights_path()
#         model_handle_builder = load_object(self.run.data.params["model_handle_builder"].strip())
#         model_handle = model_handle_builder(self.get_classes(), base_weights=base_weights)
#         model_handle.load_top_weights(weights_path)
#         return model_handle
#
#
# class PredictionModelHandle:
#     """Simplifies usage of ModelHandle instances for prediction purposes."""
#
#     def __init__(self, model_handle, classes_readable: Mapping[int, str]):
#         self.__model_handle = model_handle
#         self.__classes_readable = classes_readable
#
#     @property
#     def classes(self):
#         return self.__classes_readable
#
#     def predict(self, *args, **kwargs):
#         predict = rcompose(self.__model_handle.prep_images, self.__model_handle.model.predict)
#         return predict(*args, **kwargs)
#
#     def predict_proba(self, *args, **kwargs):
#         predict = rcompose(self.__model_handle.prep_images, self.__model_handle.model.predict_proba)
#         return predict(*args, **kwargs)
#
#
# class MlflowLoader(ModelLoader):
#
#     def __init__(self, mlruns_dir):
#         self.__mlruns_dir = mlruns_dir
#         self._base_weights = None
#
#     def load_model(self, run_id, base_weights=None) -> PredictionModelHandle:
#
#         # TODO: refac https://stackoverflow.com/questions/42735421/how-to-restrict-object-instantiation-only-via-a-factory-in-python
#         if not base_weights:
#
#             if not self._base_weights:
#                 raise IncorrectInstantiation("MlflowReader needs to be initialized via get_model_loader.")
#
#             base_weights = self._base_weights
#
#         mlflow_reader = MlflowModelReader(run_id, mlruns_dir=self.__mlruns_dir)
#         model_handel = mlflow_reader.get_model_handle(base_weights)
#         model_handle = model_handel
#         classes_readable = self.__load_classes(model_handle)
#
#         model = PredictionModelHandle(model_handle, classes_readable)
#
#         return model
#
#     @staticmethod
#     def __load_classes(model_handle):
#
#         classes = model_handle.model.classes_
#         classes_readable = np.array(model_handle.classes)
#         classes_readable_aligned = classes_readable[classes[list(range(len(classes)))]]
#
#         return classes_readable_aligned