2022-03-29 20:02:40 +02:00

73 lines
2.2 KiB
Python

import ast
import importlib
import json
import os
import posixpath
from functools import lru_cache

import mlflow

from image_prediction.redai_adapter.model import PredictionModelHandle


class MlflowModelReader:
    """Load model handles and class lists for runs stored in an MLflow tracking directory.

    ``reader[run_id]`` returns ``{"model": PredictionModelHandle, "classes": [...]}``
    for the given run.
    """

    def __init__(self, mlruns_dir=None):
        """Point MLflow at *mlruns_dir* and start with an empty run cache.

        Args:
            mlruns_dir: MLflow tracking directory. It is also used as the base
                path when re-rooting artifact URIs, so weight loading fails
                (``os.path.join`` on ``None``) if it is left unset.
        """
        self.mlruns_dir = mlruns_dir
        # Per-instance cache: an ``lru_cache`` on the method would key on
        # ``self`` and keep every reader alive for the life of the process.
        self._run_cache = {}
        # NOTE: this sets MLflow's process-global tracking URI.
        mlflow.set_tracking_uri(self.mlruns_dir)

    @staticmethod
    def __correct_artifact_uri(run_artifact_uri, base_path):
        """Re-root an artifact URI under *base_path*.

        Everything up to and including the last ``"mlruns/"`` is replaced by
        *base_path*; the remainder (``<experiment>/<run>/artifacts``) is kept.

        Raises:
            ValueError: if *run_artifact_uri* contains no ``"mlruns/"`` segment.
        """
        # Split on the LAST occurrence: the absolute prefix may itself contain
        # "mlruns/", which made the original two-way unpacking crash.
        _, separator, suffix = run_artifact_uri.rpartition("mlruns/")
        if not separator:
            raise ValueError(
                f"Artifact URI {run_artifact_uri!r} does not contain an 'mlruns/' segment"
            )
        return os.path.join(base_path, suffix)

    def __get_weights_path(self, run_id, prefix="tt"):
        """Return ``(base_weights_path, weights_path)`` for the run's estimator.

        Both files are expected under ``<artifacts>/<prefix>/train_dev/estimator``.
        """
        run = self.__get_run(run_id)
        artifact_uri = self.__correct_artifact_uri(run.info.to_proto().artifact_uri, self.mlruns_dir)
        estimator_dir = os.path.join(artifact_uri, prefix, "train_dev", "estimator")
        return (
            os.path.join(estimator_dir, "base_weights.h5"),
            os.path.join(estimator_dir, "weights.h5"),
        )

    def __get_run(self, run_id):
        """Fetch the MLflow run for *run_id*, memoized per reader instance."""
        run = self._run_cache.get(run_id)
        if run is None:
            run = mlflow.get_run(run_id)
            self._run_cache[run_id] = run
        return run

    def __get_classes(self, run_id, prefix="tt"):
        """Return the class list stored in the ``<prefix>/train_dev/estimator/classes`` param.

        Raises:
            KeyError: if the run has no such param.
        """
        run = self.__get_run(run_id)
        # MLflow param keys use "/" on every OS; os.path.join would emit "\" on Windows.
        raw = run.data.params[posixpath.join(prefix, "train_dev/estimator/classes")].strip()
        try:
            return json.loads(raw)
        except json.JSONDecodeError:
            # The value appears to be a Python list repr (single-quoted strings).
            # literal_eval parses that safely, including names containing quotes,
            # which the old blanket "'" -> '"' replacement corrupted.
            return ast.literal_eval(raw)

    def __get_model_handle(self, run_id):
        """Build the run's model handle from its logged builder and load its top weights."""
        run = self.__get_run(run_id)
        # "model_handle_builder" holds a dotted import path resolved by load_object.
        model_handle_builder = load_object(run.data.params["model_handle_builder"].strip())
        base_weights_path, weights_path = self.__get_weights_path(run_id)
        model_handle = model_handle_builder(self.__get_classes(run_id), base_weights=base_weights_path)
        model_handle.load_top_weights(weights_path)
        return model_handle

    def __get_model(self, run_id) -> "PredictionModelHandle":
        """Wrap the run's model handle in a ``PredictionModelHandle``."""
        return PredictionModelHandle(self.__get_model_handle(run_id))

    def __getitem__(self, run_id):
        """Return ``{"model": ..., "classes": ...}`` for *run_id*."""
        return {"model": self.__get_model(run_id), "classes": self.__get_classes(run_id)}
def load_object(object_path):
    """Import and return the object named by a dotted path such as ``"package.module.name"``.

    Everything before the last dot is imported as a module; the final
    component is looked up on that module with ``getattr``.

    Args:
        object_path: Fully qualified ``module.attribute`` path.

    Returns:
        The attribute found on the imported module.

    Raises:
        ValueError: if *object_path* has no module part (e.g. no dot at all).
        ImportError: if the module cannot be imported.
        AttributeError: if the module has no such attribute.
    """
    module_path, _, object_name = object_path.rpartition(".")
    if not module_path:
        # Previously this reached import_module("") and failed with the
        # unhelpful "Empty module name"; name the offending input instead.
        raise ValueError(f"Expected a dotted 'module.attribute' path, got {object_path!r}")
    module = importlib.import_module(module_path)
    return getattr(module, object_name)