2022-03-28 21:51:21 +02:00

84 lines
2.5 KiB
Python

import importlib
import json
import os
import posixpath
import warnings

import numpy as np

from image_prediction.locations import BASE_WEIGHTS
from image_prediction.model_loader.loader import ModelLoader
# Ignore DeprecationWarnings attributed to the pkg_resources module. The filter is
# installed before `import mlflow` below, presumably because that import triggers them.
warnings.filterwarnings("ignore", category=DeprecationWarning, module="pkg_resources")
import mlflow
def load_object(object_path):
    """Import and return the object named by a fully qualified dotted path.

    Example: ``load_object("os.path.join")`` returns ``os.path.join``.

    Args:
        object_path: Dotted path of the form ``"package.module.Name"``.

    Returns:
        The attribute ``Name`` of the imported module ``package.module``.

    Raises:
        ValueError: If ``object_path`` has no module part (e.g. contains no ".").
        ImportError: If the module cannot be imported.
        AttributeError: If the module has no attribute with that name.
    """
    module_path, _, object_name = object_path.rpartition(".")
    if not module_path:
        # Without this check importlib fails with a bare "Empty module name".
        raise ValueError(f"Expected a dotted path like 'package.module.Name', got {object_path!r}")
    module = importlib.import_module(module_path)
    return getattr(module, object_name)
def to_local_path(uri):
    """Convert a ``file://`` URI into a plain filesystem path.

    ``"file:///tmp/x"`` becomes ``"/tmp/x"``. Input that does not start with the
    ``file://`` scheme prefix is returned unchanged rather than having its first
    seven characters cut off.
    """
    scheme = "file://"
    if uri.startswith(scheme):
        return uri[len(scheme):]
    return uri
class MlflowModelReader:
    """Read a trained model's parameters and artifacts from an MLflow run.

    The run is fetched through the MLflow tracking API, and its artifact URI is
    re-rooted under ``mlruns_dir``. Artifacts are then read from that local
    ``mlruns`` directory rather than from the location recorded when the run was
    logged.
    """

    def __init__(self, run_id, mlruns_dir=None):
        """
        Args:
            run_id: ID of the MLflow run to read.
            mlruns_dir: Local ``mlruns`` directory. It is used both as the
                tracking URI and as the new root for the run's artifact paths.
                Despite the ``None`` default, a path is effectively required,
                because ``os.path.join(None, ...)`` raises ``TypeError``.
        """
        mlflow.set_tracking_uri(mlruns_dir)
        self.run_id = run_id
        self.run = mlflow.get_run(run_id)
        self.artifact_uri = self.__correct_artifact_uri(self.run.info.to_proto().artifact_uri, mlruns_dir)

    @staticmethod
    def __correct_artifact_uri(run_artifact_uri, base_path):
        """Re-root ``run_artifact_uri`` under ``base_path``.

        Everything up to and including the last ``"mlruns/"`` is replaced by
        ``base_path``. For example, ``file:///old/mlruns/0/<run>/artifacts``
        becomes ``<base_path>/0/<run>/artifacts``.

        Raises:
            ValueError: If ``run_artifact_uri`` does not contain ``"mlruns/"``.
        """
        # rpartition, rather than split plus 2-tuple unpacking, tolerates parent
        # directories that are themselves named "mlruns".
        _, separator, suffix = run_artifact_uri.rpartition("mlruns/")
        if not separator:
            raise ValueError(f"Artifact URI {run_artifact_uri!r} does not contain an 'mlruns/' directory")
        return os.path.join(base_path, suffix)

    def get_weights_path(self, prefix="tt"):
        """Return the filesystem path of ``<prefix>/train_dev/estimator/weights.h5`` within the run's artifacts."""
        return os.path.join(self.artifact_uri, prefix, "train_dev", "estimator", "weights.h5")

    def get_classes(self, prefix="tt"):
        """Return the class list stored in the run param ``<prefix>/train_dev/estimator/classes``."""
        # MLflow param keys use "/" on every OS. os.path.join would produce "\\"
        # on Windows and the key would not be found.
        key = posixpath.join(prefix, "train_dev/estimator/classes")
        # The stored value appears to be a Python repr (single quotes). Swapping
        # the quotes lets it parse as JSON.
        # NOTE(review): this breaks if a class name itself contains a quote character.
        return json.loads(self.run.data.params[key].replace("'", '"'))

    def get_model_handle(self, base_weights=None, prefix="tt"):
        """Build the model handle described by this run and load its trained top weights.

        The run param ``model_handle_builder`` names a dotted-path callable. It is
        called with the run's classes and ``base_weights``, and the result gets
        ``load_top_weights(<weights path>)`` applied before it is returned.

        Args:
            base_weights: Forwarded to the builder as ``base_weights``.
            prefix: Artifact/param prefix. It selects the classes and the weights file.
        """
        weights_path = self.get_weights_path(prefix)
        model_handle_builder = load_object(self.run.data.params["model_handle_builder"].strip())
        model_handle = model_handle_builder(self.get_classes(prefix), base_weights=base_weights)
        model_handle.load_top_weights(weights_path)
        return model_handle
class MlflowLoader(ModelLoader):
    """ModelLoader that builds model handles from runs in a local MLflow ``mlruns`` directory.

    The most recently loaded model handle is cached together with its run ID.
    Repeated calls for the same run reuse it, while a different run ID triggers
    a fresh load.
    """

    def __init__(self, mlruns_dir):
        self.__mlruns_dir = mlruns_dir
        self._model_handle = None
        # Run ID that the cached handle belongs to. Without it, a request for a
        # different run would silently get the previously loaded model.
        self._run_id = None

    def load_model(self, run_id):
        """Return the model handle for ``run_id``, building it with ``BASE_WEIGHTS`` if not cached."""
        # Compare with `is None` rather than truthiness, since the handle is an
        # arbitrary object that might define __bool__ or __len__.
        if self._model_handle is None or run_id != self._run_id:
            mlflow_reader = MlflowModelReader(run_id, mlruns_dir=self.__mlruns_dir)
            self._model_handle = mlflow_reader.get_model_handle(BASE_WEIGHTS)
            self._run_id = run_id
        return self._model_handle

    def load_classes(self, run_id):
        """Return the readable class names, ordered to match the model's ``classes_``.

        Each entry of ``model_handle.model.classes_`` is used as an index into
        ``model_handle.classes``. This assumes ``classes_`` holds integer labels;
        TODO confirm against the model handle implementation.

        Returns:
            numpy.ndarray of readable class names.
        """
        model_handle = self.load_model(run_id)
        label_indices = np.asarray(model_handle.model.classes_)
        classes_readable = np.array(model_handle.classes)
        # Integer-array (fancy) indexing: position i gets the name for label classes_[i].
        return classes_readable[label_indices]