102 lines
3.6 KiB
Python
102 lines
3.6 KiB
Python
"""This module translates between the new ModelLoader API and the inconsistent and historically grown redai model and
|
|
MLflow API. It also works around the fact that the model artifacts are currently not stored in a single place: the
need to load the base weights of the pre-trained model only became apparent after the MLflow storage and the
MlflowModelReader class had been designed, which is why the code in this module is so unclean. In the future, a
non-ad-hoc solution should be used that offers a clean API and storage solution: either implement a well-designed
MLflow-based solution, look into an alternative such as WandB, or use a platform solution such as AWS.
|
|
"""
|
|
import importlib
|
|
import json
|
|
import os
|
|
import warnings
|
|
|
|
import numpy as np
|
|
|
|
from image_prediction.exceptions import IncorrectInstantiation
|
|
from image_prediction.model_loader.loader import ModelLoader
|
|
|
|
warnings.filterwarnings("ignore", category=DeprecationWarning, module="pkg_resources")
|
|
|
|
import mlflow
|
|
|
|
|
|
def load_object(object_path):
    """Import and return the object addressed by a dotted path.

    Everything before the last dot is imported as a module with
    ``importlib.import_module``; the part after the last dot is then looked up
    on that module with ``getattr``.

    Args:
        object_path: Dotted path such as ``"package.module.name"``.

    Returns:
        The attribute ``name`` of the imported module.
    """
    # rpartition splits on the last dot only: (module path, ".", attribute name).
    module_name, _, attribute = object_path.rpartition(".")
    return getattr(importlib.import_module(module_name), attribute)
|
|
|
|
|
|
def to_local_path(uri):
    """Return ``uri`` without a leading ``file://`` scheme.

    Only a literal ``file://`` prefix is removed. Input that does not start
    with it (for example a path that already is local) is returned unchanged
    instead of losing its first seven characters.

    Args:
        uri: A ``file://`` URI or a plain path, as a string.

    Returns:
        The string with the ``file://`` prefix removed, if it was present.
    """
    scheme = "file://"
    if uri.startswith(scheme):
        return uri[len(scheme):]
    return uri
|
|
|
|
|
|
class MlflowModelReader:
    """Read the artifacts and logged parameters of a single MLflow run.

    Attributes are set in ``__init__``:
        run_id: The id the reader was created for.
        run: The run object returned by ``mlflow.get_run(run_id)``.
        artifact_uri: The run's artifact location, rebased onto ``mlruns_dir``
            when one was given.
    """

    def __init__(self, run_id, mlruns_dir=None):
        """Point MLflow at ``mlruns_dir`` and fetch the run ``run_id``.

        Args:
            run_id: Id of the MLflow run to read from.
            mlruns_dir: Tracking location passed to ``mlflow.set_tracking_uri``
                and used as the new base of the run's artifact location. With
                ``None`` the artifact URI recorded by MLflow is used as is.
        """
        # NOTE: this changes MLflow's process-wide tracking URI.
        mlflow.set_tracking_uri(mlruns_dir)

        self.run_id = run_id
        self.run = mlflow.get_run(run_id)
        self.artifact_uri = self.__correct_artifact_uri(self.run.info.to_proto().artifact_uri, mlruns_dir)

    @staticmethod
    def __correct_artifact_uri(run_artifact_uri, base_path):
        """Rebase the part of ``run_artifact_uri`` after ``"mlruns/"`` onto ``base_path``.

        Args:
            run_artifact_uri: Artifact URI as recorded in the run.
            base_path: Directory that replaces everything up to and including
                ``"mlruns/"``. ``None`` means nothing to rebase onto.

        Returns:
            ``os.path.join(base_path, <suffix>)``, or ``run_artifact_uri``
            unchanged when ``base_path`` is ``None``.

        Raises:
            ValueError: If ``"mlruns/"`` does not occur exactly once in
                ``run_artifact_uri``, so the suffix cannot be determined.
        """
        if base_path is None:
            # os.path.join(None, ...) would raise a TypeError; keep the recorded URI.
            return run_artifact_uri

        fragments = run_artifact_uri.split("mlruns/")
        if len(fragments) != 2:
            raise ValueError(
                "Expected exactly one 'mlruns/' in artifact URI {!r}, found {}.".format(
                    run_artifact_uri, len(fragments) - 1
                )
            )
        return os.path.join(base_path, fragments[1])

    def get_weights_path(self, prefix="tt"):
        """Return ``<artifact_uri>/<prefix>/train_dev/estimator/weights.h5``.

        Args:
            prefix: First path component below the artifact location.

        Returns:
            The joined path as a string; the file's existence is not checked.
        """
        return os.path.join(self.artifact_uri, prefix, "train_dev", "estimator", "weights.h5")

    def get_classes(self, prefix="tt"):
        """Return the classes logged as a run parameter.

        The parameter ``<prefix>/train_dev/estimator/classes`` is read from the
        run, its single quotes are replaced by double quotes and the result is
        parsed as JSON.

        Args:
            prefix: First component of the parameter key.

        Returns:
            The parsed JSON value (presumably a list of class names - the
            stored format is not visible here).

        Raises:
            KeyError: If the parameter was not logged for this run.
        """
        # The key is built with os.path.join, exactly as before; the logging side is
        # not visible here, so the separator handling is deliberately left untouched.
        raw_classes = self.run.data.params[os.path.join(prefix, "train_dev/estimator/classes")]
        # The stored value uses single quotes, which JSON does not accept.
        return json.loads(raw_classes.replace("'", '"'))

    def get_model_handle(self, base_weights=None, prefix="tt"):
        """Build the model handle for this run and load its top weights.

        The builder is resolved with ``load_object`` from the dotted path stored
        in the run parameter ``model_handle_builder``, called with the run's
        classes and ``base_weights``, and then asked to load the weights file
        returned by ``get_weights_path``.

        Args:
            base_weights: Passed through to the builder as ``base_weights``.
            prefix: Forwarded to ``get_weights_path`` and ``get_classes``;
                defaults to the same value those methods default to.

        Returns:
            The model handle with the top weights loaded.
        """
        weights_path = self.get_weights_path(prefix)
        model_handle_builder = load_object(self.run.data.params["model_handle_builder"].strip())
        model_handle = model_handle_builder(self.get_classes(prefix), base_weights=base_weights)
        model_handle.load_top_weights(weights_path)
        return model_handle
|
|
|
|
|
|
class MlflowLoader(ModelLoader):
    """ModelLoader that builds model handles from MLflow runs and caches the last one.

    ``_base_weights`` starts as ``None`` and is not set anywhere in this class;
    ``load_model`` raises unless it has been set from outside or base weights
    are passed explicitly.
    """

    def __init__(self, mlruns_dir):
        """Remember the MLflow directory and start with an empty cache.

        Args:
            mlruns_dir: Passed as ``mlruns_dir`` to ``MlflowModelReader``.
        """
        self.__mlruns_dir = mlruns_dir
        self._model_handle = None  # cached handle of the most recently loaded run
        self.__last_run_id = None  # run id the cached handle belongs to
        self._base_weights = None

    def load_model(self, run_id, base_weights=None):
        """Return the model handle for ``run_id``, loading it if it is not cached.

        Args:
            run_id: Id of the MLflow run to load.
            base_weights: Base weights handed to the model handle builder; falls
                back to ``self._base_weights`` when falsy.

        Returns:
            The model handle produced by ``MlflowModelReader.get_model_handle``.

        Raises:
            IncorrectInstantiation: If neither ``base_weights`` nor
                ``self._base_weights`` is set.
        """
        if not base_weights:
            if not self._base_weights:
                raise IncorrectInstantiation("MlflowReader needs to be initialized via get_model_loader.")
            base_weights = self._base_weights

        # (Re)load when nothing is cached yet or a different run is requested.
        # The previous condition (`not handle and run_id == last_run_id`) was
        # inverted: on a fresh loader it was false for every real run id, so no
        # model was ever loaded and None was returned.
        if self._model_handle is None or run_id != self.__last_run_id:
            mlflow_reader = MlflowModelReader(run_id, mlruns_dir=self.__mlruns_dir)
            self._model_handle = mlflow_reader.get_model_handle(base_weights)
            self.__last_run_id = run_id

        return self._model_handle

    def load_classes(self, run_id):
        """Return the readable class names ordered by the model's ``classes_``.

        Args:
            run_id: Id of the MLflow run whose model is loaded via ``load_model``.

        Returns:
            A numpy array holding ``model_handle.classes`` indexed by
            ``model_handle.model.classes_``.
        """
        model_handle = self.load_model(run_id)

        # classes_ is used as an index array into the readable names; the former
        # `classes[list(range(len(classes)))]` only selected every element in order.
        class_indices = np.asarray(model_handle.model.classes_)
        classes_readable = np.array(model_handle.classes)

        return classes_readable[class_indices]
|