72 lines
2.2 KiB
Python
72 lines
2.2 KiB
Python
import importlib
|
|
import json
|
|
import os
|
|
from functools import lru_cache
|
|
|
|
import mlflow
|
|
|
|
from image_prediction.redai_adapter.model import PredictionModelHandle
|
|
|
|
|
|
class MlflowModelReader:
    """Build prediction models from the artifacts and params of MLflow runs.

    Indexing an instance with a run id (``reader[run_id]``) returns a dict
    holding the wrapped model and the classes stored for that run.
    """

    def __init__(self, mlruns_dir=None):
        """Remember ``mlruns_dir`` and hand it to ``mlflow.set_tracking_uri``.

        Args:
            mlruns_dir: Location of the MLflow tracking data. It is also the
                base path under which artifact locations are rebuilt.
        """
        self.mlruns_dir = mlruns_dir
        # Runs are cached per instance. A functools.lru_cache on the method
        # would key on ``self`` and keep every reader alive indefinitely.
        self.__run_cache = {}
        mlflow.set_tracking_uri(self.mlruns_dir)

    @staticmethod
    def __correct_artifact_uri(run_artifact_uri, base_path):
        """Re-root a recorded artifact URI under ``base_path``.

        The part of ``run_artifact_uri`` that follows the last ``"mlruns/"``
        is joined onto ``base_path``.

        Args:
            run_artifact_uri: Artifact URI as recorded on the run.
            base_path: Directory that replaces everything up to and
                including the ``"mlruns/"`` marker.

        Returns:
            The artifact location below ``base_path``.

        Raises:
            ValueError: If ``run_artifact_uri`` contains no ``"mlruns/"``.
        """
        # rpartition tolerates parent directories that are themselves named
        # "mlruns"; plain two-value unpacking of split() failed on those.
        _, marker, suffix = run_artifact_uri.rpartition("mlruns/")
        if not marker:
            raise ValueError(
                f"artifact URI {run_artifact_uri!r} does not contain 'mlruns/'"
            )
        return os.path.join(base_path, suffix)

    def __get_weights_path(self, run_id, prefix="tt"):
        """Return ``(base_weights_path, weights_path)`` for a run.

        Both files are expected in ``<artifacts>/<prefix>/train_dev/estimator``
        where ``<artifacts>`` is the run's artifact URI re-rooted under
        ``self.mlruns_dir``.
        """
        run = self.__get_run(run_id)

        artifact_uri = self.__correct_artifact_uri(
            run.info.to_proto().artifact_uri, self.mlruns_dir
        )
        estimator_dir = os.path.join(artifact_uri, prefix, "train_dev", "estimator")

        base_path = os.path.join(estimator_dir, "base_weights.h5")
        weights_path = os.path.join(estimator_dir, "weights.h5")

        return base_path, weights_path

    def __get_run(self, run_id):
        """Return the MLflow run for ``run_id``, fetching it at most once."""
        try:
            return self.__run_cache[run_id]
        except KeyError:
            run = mlflow.get_run(run_id)
            self.__run_cache[run_id] = run
            return run

    def __get_classes(self, run_id, prefix="tt"):
        """Return the classes stored in the run's params.

        The param value is decoded as JSON after replacing every single
        quote with a double quote.
        """
        run = self.__get_run(run_id)

        # NOTE(review): the key is built with os.path.join, so its separator
        # is platform dependent — confirm this matches how the param is logged.
        param_key = os.path.join(prefix, "train_dev/estimator/classes")
        classes = json.loads(run.data.params[param_key].replace("'", '"'))

        return classes

    def __get_model_handle(self, run_id):
        """Build the model handle for a run and load its weights.

        The builder is resolved with ``load_object`` from the run's
        ``model_handle_builder`` param, called with the run's classes and the
        base weights path, and then asked to load the top weights.
        """
        run = self.__get_run(run_id)

        model_handle_builder = load_object(run.data.params["model_handle_builder"].strip())

        base_weights_path, weights_path = self.__get_weights_path(run_id)

        model_handle = model_handle_builder(
            self.__get_classes(run_id), base_weights=base_weights_path
        )
        model_handle.load_top_weights(weights_path)

        return model_handle

    def __get_model(self, run_id) -> "PredictionModelHandle":
        """Return the run's model handle wrapped in a ``PredictionModelHandle``."""
        model_handle = self.__get_model_handle(run_id)
        return PredictionModelHandle(model_handle)

    def __getitem__(self, run_id):
        """Return ``{"model": ..., "classes": ...}`` for ``run_id``."""
        return {"model": self.__get_model(run_id), "classes": self.__get_classes(run_id)}
|
|
|
|
|
|
def load_object(object_path):
    """Import and return the object named by a dotted path.

    The text before the last dot is imported as a module and the text after
    it is looked up as an attribute of that module.

    Args:
        object_path: Dotted path such as ``"package.module.name"``.

    Returns:
        The attribute ``name`` of the imported module.

    Raises:
        ValueError: If ``object_path`` has no module part before the last dot.
        ImportError: If the module cannot be imported.
        AttributeError: If the module has no attribute with that name.
    """
    module_path, _, object_name = object_path.rpartition(".")

    # Fail early with a message naming the offending path; otherwise
    # import_module("") raises a generic "Empty module name" error.
    if not module_path:
        raise ValueError(
            f"expected a dotted path like 'package.module.name', got {object_path!r}"
        )

    module = importlib.import_module(module_path)
    return getattr(module, object_name)
|