2022-03-29 11:02:43 +02:00

102 lines
3.6 KiB
Python

"""This module translates between the new ModelLoader API and the inconsistent and historically grown redai model and
MLflow API. It also works around the fact that the model artifacts are currently not stored in a single place: the need
to load the base weights of the pre-trained model became apparent only after the MLflow storage and the
MlflowModelReader class had been designed. That is why the code in this module is so unclean. In the future, a
non-ad-hoc solution should be used that offers a clean API and storage solution: either implement a well-designed
MLflow-based solution, look into an alternative such as WandB, or use a platform solution such as AWS.
"""
import importlib
import json
import os
import warnings
import numpy as np
from image_prediction.exceptions import IncorrectInstantiation
from image_prediction.model_loader.loader import ModelLoader
# Ignore DeprecationWarnings attributed to the ``pkg_resources`` module. The filter is registered ahead of the
# ``mlflow`` import that follows — presumably because that import triggers the warning (TODO confirm).
warnings.filterwarnings("ignore", category=DeprecationWarning, module="pkg_resources")
import mlflow
def load_object(object_path):
    """Import and return the object addressed by a dotted path.

    Everything before the last dot is imported as a module; the component after the last dot is then looked up as
    an attribute of that module.

    Args:
        object_path: Dotted path such as ``"package.module.attribute"``.

    Returns:
        The attribute named by the last path component.

    Raises:
        ValueError: From ``importlib.import_module`` when the path has no module part (e.g. no dot at all).
        ImportError: When the module part cannot be imported.
        AttributeError: When the module has no attribute with the requested name.
    """
    # rpartition splits on the LAST dot, i.e. "a.b.c" -> ("a.b", ".", "c").
    module_name, _, attribute_name = object_path.rpartition(".")
    return getattr(importlib.import_module(module_name), attribute_name)
def to_local_path(uri):
    """Return ``uri`` without its first seven characters.

    NOTE(review): seven is the length of ``"file://"``, so this presumably strips that scheme from a file URI —
    confirm against callers. The prefix is NOT checked: any input loses its first seven characters, and inputs
    shorter than that yield an empty result.

    Args:
        uri: Sliceable value (typically a string) to strip the leading characters from.

    Returns:
        ``uri`` from index 7 onwards.
    """
    prefix_length = 7
    local_path = uri[prefix_length:]
    return local_path
class MlflowModelReader:
    """Read weights, class names and the model handle belonging to a single MLflow run."""

    def __init__(self, run_id, mlruns_dir=None):
        """Fetch the run and resolve where its artifacts are located.

        Args:
            run_id: ID of the MLflow run to read from.
            mlruns_dir: Tracking URI / base directory handed to ``mlflow.set_tracking_uri``. When given, the part of
                the run's recorded artifact URI that follows ``"mlruns/"`` is re-rooted under this directory. When
                ``None``, the recorded artifact URI is used unchanged.
        """
        mlflow.set_tracking_uri(mlruns_dir)
        self.run_id = run_id
        self.run = mlflow.get_run(run_id)
        self.artifact_uri = self.__correct_artifact_uri(self.run.info.to_proto().artifact_uri, mlruns_dir)

    @staticmethod
    def __correct_artifact_uri(run_artifact_uri, base_path):
        """Re-root the artifact URI recorded in the run under ``base_path``.

        Args:
            run_artifact_uri: Artifact URI as stored in the run info.
            base_path: Directory that replaces everything up to and including ``"mlruns/"``; ``None`` keeps the
                recorded URI as is.

        Returns:
            ``base_path`` joined with the part of ``run_artifact_uri`` after ``"mlruns/"``, or ``run_artifact_uri``
            itself when ``base_path`` is ``None``.

        Raises:
            ValueError: If ``run_artifact_uri`` does not contain ``"mlruns/"`` exactly once.
        """
        if base_path is None:
            # Nothing to re-root against; os.path.join(None, ...) would raise a TypeError.
            return run_artifact_uri

        parts = run_artifact_uri.split("mlruns/")
        if len(parts) != 2:
            # Same exception type as the tuple-unpacking failure this replaces, but with an actionable message.
            raise ValueError(
                f"Expected exactly one 'mlruns/' segment in artifact URI {run_artifact_uri!r}, "
                f"found {len(parts) - 1}."
            )
        return os.path.join(base_path, parts[1])

    def get_weights_path(self, prefix="tt"):
        """Return the path ``<artifact_uri>/<prefix>/train_dev/estimator/weights.h5``.

        Args:
            prefix: First path component below the artifact location.
        """
        return os.path.join(self.artifact_uri, prefix, "train_dev", "estimator", "weights.h5")

    def get_classes(self, prefix="tt"):
        """Return the classes stored in the run parameter ``<prefix>/train_dev/estimator/classes``.

        The parameter value is parsed as JSON after replacing single quotes with double quotes, so a value that
        itself contains an apostrophe will fail to parse.

        Args:
            prefix: First component of the parameter key.

        Returns:
            The decoded JSON value.

        Raises:
            KeyError: If the run has no such parameter.
            json.JSONDecodeError: If the converted value is not valid JSON.
        """
        # Parameter keys are not filesystem paths: always join with "/" so the key does not depend on the OS
        # (os.path.join would insert a backslash on Windows). Identical to os.path.join on POSIX systems.
        import posixpath

        param_key = posixpath.join(prefix, "train_dev/estimator/classes")
        raw_value = self.run.data.params[param_key]
        return json.loads(raw_value.replace("'", '"'))

    def get_model_handle(self, base_weights=None):
        """Build the model handle of the run and load the run's weights into it.

        The builder is imported from the dotted path stored in the run parameter ``model_handle_builder``. It is
        called with the run's classes and ``base_weights``; afterwards ``load_top_weights`` is called on the result
        with the run's weights path.

        Args:
            base_weights: Forwarded to the builder as the ``base_weights`` keyword argument.

        Returns:
            The object returned by the builder.
        """
        weights_path = self.get_weights_path()
        builder_path = self.run.data.params["model_handle_builder"].strip()
        model_handle_builder = load_object(builder_path)
        model_handle = model_handle_builder(self.get_classes(), base_weights=base_weights)
        model_handle.load_top_weights(weights_path)
        return model_handle
class MlflowLoader(ModelLoader):
    """ModelLoader that reads models through MlflowModelReader and caches the most recently loaded one."""

    def __init__(self, mlruns_dir):
        """Store the MLflow runs directory and start with an empty cache.

        Args:
            mlruns_dir: Passed to ``MlflowModelReader`` as ``mlruns_dir`` whenever a model is loaded.
        """
        self.__mlruns_dir = mlruns_dir
        self._model_handle = None  # cached handle of the last loaded run
        self.__last_run_id = None  # run ID the cached handle belongs to
        self._base_weights = None  # fallback base weights; not set inside this class

    def load_model(self, run_id, base_weights=None):
        """Return the model handle for ``run_id``, loading it only when it is not already cached.

        Args:
            run_id: ID of the MLflow run whose model should be loaded.
            base_weights: Base weights forwarded to the model handle builder. When falsy, ``self._base_weights`` is
                used instead.

        Returns:
            The model handle produced by ``MlflowModelReader.get_model_handle``.

        Raises:
            IncorrectInstantiation: If neither ``base_weights`` nor ``self._base_weights`` is set.
        """
        if not base_weights:
            if not self._base_weights:
                raise IncorrectInstantiation("MlflowReader needs to be initialized via get_model_loader.")
            base_weights = self._base_weights

        # (Re)load when nothing is cached yet or a different run is requested. The previous condition
        # (`not handle and run_id == last_run_id`) never loaded on the first call and never refreshed on a new run.
        if self._model_handle is None or run_id != self.__last_run_id:
            mlflow_reader = MlflowModelReader(run_id, mlruns_dir=self.__mlruns_dir)
            self._model_handle = mlflow_reader.get_model_handle(base_weights)
            self.__last_run_id = run_id

        return self._model_handle

    def load_classes(self, run_id):
        """Return the readable class names of the run's model, indexed by the model's ``classes_``.

        Args:
            run_id: ID of the MLflow run whose model should be used.

        Returns:
            NumPy array built from ``model_handle.classes`` and indexed with ``model_handle.model.classes_``.
        """
        model_handle = self.load_model(run_id)
        classes = model_handle.model.classes_
        classes_readable = np.array(model_handle.classes)
        # NOTE(review): assumes `classes` is an integer NumPy array of indices into `classes_readable` — confirm.
        # Indexing `classes` with 0..n-1 selects its entries in their stored order.
        stored_order = list(range(len(classes)))
        return classes_readable[classes[stored_order]]