# 2022-03-29 17:25:06 +02:00
#
# 124 lines
# 4.6 KiB
# Python

# """This module translates between the new ModelLoader API and the inconsistent, historically grown redai model and
# MLflow APIs. It also works around the fact that the model artifacts are currently not stored in a single place: the
# base weights of the pre-trained model must be loaded separately, a need that became apparent only after the MLflow
# storage layout and the MlflowModelReader class had been designed. That is why the code in this module is so messy.
# In the future, a non-ad-hoc solution that offers a clean API and storage layout should be used: either implement a
# well-designed MLflow-based solution, or look into an alternative such as WandB or a platform solution such as AWS.
# """
# import importlib
# import json
# import os
# import warnings
# from typing import Mapping
#
# import numpy as np
# from funcy import rcompose
#
# from image_prediction.exceptions import IncorrectInstantiation
# from image_prediction.model_loader.loader import ModelLoader
#
# warnings.filterwarnings("ignore", category=DeprecationWarning, module="pkg_resources")
#
# import mlflow
#
#
# def load_object(object_path):
# path_fragments = object_path.split(".")
#
# module_path = ".".join(path_fragments[:-1])
# object_name = path_fragments[-1]
#
# module = importlib.import_module(module_path)
# return getattr(module, object_name)
#
#
# def to_local_path(uri):
# return uri[7:]
#
#
# class MlflowModelReader:
#
# def __init__(self, run_id, mlruns_dir=None):
# mlflow.set_tracking_uri(mlruns_dir)
#
# self.run_id = run_id
# self.run = mlflow.get_run(run_id)
# self.artifact_uri = self.__correct_artifact_uri(self.run.info.to_proto().artifact_uri, mlruns_dir)
#
# @staticmethod
# def __correct_artifact_uri(run_artifact_uri, base_path):
# _, suffix = run_artifact_uri.split("mlruns/")
# return os.path.join(base_path, suffix)
#
# def get_weights_path(self, prefix="tt"):
# path = os.path.join(self.artifact_uri, prefix, "train_dev", "estimator", "weights.h5")
# return path
#
# def get_classes(self, prefix="tt"):
# classes = json.loads(
# self.run.data.params[os.path.join(prefix, "train_dev/estimator/classes")].replace("'", '"')
# )
# return classes
#
# def get_model_handle(self, base_weights=None):
# weights_path = self.get_weights_path()
# model_handle_builder = load_object(self.run.data.params["model_handle_builder"].strip())
# model_handle = model_handle_builder(self.get_classes(), base_weights=base_weights)
# model_handle.load_top_weights(weights_path)
# return model_handle
#
#
# class PredictionModelHandle:
# """Simplifies usage of ModelHandle instances for prediction purposes."""
#
# def __init__(self, model_handle, classes_readable: Mapping[int, str]):
# self.__model_handle = model_handle
# self.__classes_readable = classes_readable
#
# @property
# def classes(self):
# return self.__classes_readable
#
# def predict(self, *args, **kwargs):
# predict = rcompose(self.__model_handle.prep_images, self.__model_handle.model.predict)
# return predict(*args, **kwargs)
#
# def predict_proba(self, *args, **kwargs):
# predict = rcompose(self.__model_handle.prep_images, self.__model_handle.model.predict_proba)
# return predict(*args, **kwargs)
#
#
# class MlflowLoader(ModelLoader):
#
# def __init__(self, mlruns_dir):
# self.__mlruns_dir = mlruns_dir
# self._base_weights = None
#
# def load_model(self, run_id, base_weights=None) -> PredictionModelHandle:
#
# # TODO: refactor so that this loader can only be instantiated via a factory (get_model_loader); see https://stackoverflow.com/questions/42735421/how-to-restrict-object-instantiation-only-via-a-factory-in-python
# if not base_weights:
#
# if not self._base_weights:
# raise IncorrectInstantiation("MlflowReader needs to be initialized via get_model_loader.")
#
# base_weights = self._base_weights
#
# mlflow_reader = MlflowModelReader(run_id, mlruns_dir=self.__mlruns_dir)
# model_handel = mlflow_reader.get_model_handle(base_weights)
# model_handle = model_handel
# classes_readable = self.__load_classes(model_handle)
#
# model = PredictionModelHandle(model_handle, classes_readable)
#
# return model
#
# @staticmethod
# def __load_classes(model_handle):
#
# classes = model_handle.model.classes_
# classes_readable = np.array(model_handle.classes)
# classes_readable_aligned = classes_readable[classes[list(range(len(classes)))]]
#
# return classes_readable_aligned