restructuring of modules
This commit is contained in:
parent
d33a882d65
commit
358d7ecd91
@ -1,104 +1,5 @@
|
||||
"""This module translates between the new ModelLoader API and the inconsistent and historically grown redai model and
|
||||
MLflow API as well as the circumstance, that the model artifacts are currently not stored at a single place, due to the
|
||||
need of loading the base weights of the pre-trained model, that became apparent at a later point than the design of the
|
||||
MLflow storage and MlflowModelReader class; that is why the code in this module is so unclean. In the future, a
|
||||
non-adhoc solution should be used that offers a clean API and storage solution. Either implement a well-designed MLflow
|
||||
based solution or look into an alternative such as WandB or use a platform solution such as AWS.
|
||||
"""
|
||||
|
||||
import os
|
||||
|
||||
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
|
||||
|
||||
import importlib
|
||||
import json
|
||||
import os
|
||||
from functools import lru_cache
|
||||
|
||||
from funcy import rcompose
|
||||
|
||||
from image_prediction.model_loader.database.connector import DatabaseConnector
|
||||
|
||||
import mlflow
|
||||
|
||||
|
||||
class PredictionModelHandle:
|
||||
"""Simplifies usage of ModelHandle instances for prediction purposes."""
|
||||
|
||||
def __init__(self, model_handle):
|
||||
self.__model_handle = model_handle
|
||||
|
||||
def predict(self, *args, **kwargs):
|
||||
predict = rcompose(self.__model_handle.prep_images, self.__model_handle.model.predict)
|
||||
return predict(*args, **kwargs)
|
||||
|
||||
def predict_proba(self, *args, **kwargs):
|
||||
predict = rcompose(self.__model_handle.prep_images, self.__model_handle.model.predict_proba)
|
||||
return predict(*args, **kwargs)
|
||||
|
||||
|
||||
class MlflowModelReader:
|
||||
|
||||
def __init__(self, mlruns_dir=None):
|
||||
self.mlruns_dir = mlruns_dir
|
||||
mlflow.set_tracking_uri(self.mlruns_dir)
|
||||
|
||||
@staticmethod
|
||||
def __correct_artifact_uri(run_artifact_uri, base_path):
|
||||
_, suffix = run_artifact_uri.split("mlruns/")
|
||||
return os.path.join(base_path, suffix)
|
||||
|
||||
def __get_weights_path(self, run_id, prefix="tt"):
|
||||
run = self.__get_run(run_id)
|
||||
|
||||
artifact_uri = self.__correct_artifact_uri(run.info.to_proto().artifact_uri, self.mlruns_dir)
|
||||
path = os.path.join(artifact_uri, prefix, "train_dev", "estimator")
|
||||
|
||||
base_path = os.path.join(path, "base_weights.h5")
|
||||
weights_path = os.path.join(path, "weights.h5")
|
||||
|
||||
return base_path, weights_path
|
||||
|
||||
@lru_cache(maxsize=None)
|
||||
def __get_run(self, run_id):
|
||||
return mlflow.get_run(run_id)
|
||||
|
||||
def __get_classes(self, run_id, prefix="tt"):
|
||||
run = self.__get_run(run_id)
|
||||
|
||||
classes = json.loads(run.data.params[os.path.join(prefix, "train_dev/estimator/classes")].replace("'", '"'))
|
||||
|
||||
return classes
|
||||
|
||||
def __get_model_handle(self, run_id):
|
||||
run = self.__get_run(run_id)
|
||||
|
||||
model_handle_builder = load_object(run.data.params["model_handle_builder"].strip())
|
||||
|
||||
base_weights_path, weights_path = self.__get_weights_path(run_id)
|
||||
|
||||
model_handle = model_handle_builder(self.__get_classes(run_id), base_weights=base_weights_path)
|
||||
model_handle.load_top_weights(weights_path)
|
||||
|
||||
return model_handle
|
||||
|
||||
def __get_model(self, run_id) -> PredictionModelHandle:
|
||||
model_handle = self.__get_model_handle(run_id)
|
||||
model = PredictionModelHandle(model_handle)
|
||||
return model
|
||||
|
||||
def __getitem__(self, run_id):
|
||||
return {"model": self.__get_model(run_id), "classes": self.__get_classes(run_id)}
|
||||
|
||||
|
||||
def load_object(object_path):
|
||||
path_fragments = object_path.split(".")
|
||||
|
||||
module_path = ".".join(path_fragments[:-1])
|
||||
object_name = path_fragments[-1]
|
||||
|
||||
module = importlib.import_module(module_path)
|
||||
return getattr(module, object_name)
|
||||
from image_prediction.redai_adapter.mlflow import MlflowModelReader
|
||||
|
||||
|
||||
class MlflowConnector(DatabaseConnector):
|
||||
|
||||
0
image_prediction/redai_adapter/__init__.py
Normal file
0
image_prediction/redai_adapter/__init__.py
Normal file
72
image_prediction/redai_adapter/mlflow.py
Normal file
72
image_prediction/redai_adapter/mlflow.py
Normal file
@ -0,0 +1,72 @@
|
||||
import importlib
|
||||
import json
|
||||
import os
|
||||
from functools import lru_cache
|
||||
|
||||
import mlflow
|
||||
|
||||
from image_prediction.redai_adapter.model import PredictionModelHandle
|
||||
|
||||
|
||||
class MlflowModelReader:
|
||||
|
||||
def __init__(self, mlruns_dir=None):
|
||||
self.mlruns_dir = mlruns_dir
|
||||
mlflow.set_tracking_uri(self.mlruns_dir)
|
||||
|
||||
@staticmethod
|
||||
def __correct_artifact_uri(run_artifact_uri, base_path):
|
||||
_, suffix = run_artifact_uri.split("mlruns/")
|
||||
return os.path.join(base_path, suffix)
|
||||
|
||||
def __get_weights_path(self, run_id, prefix="tt"):
|
||||
run = self.__get_run(run_id)
|
||||
|
||||
artifact_uri = self.__correct_artifact_uri(run.info.to_proto().artifact_uri, self.mlruns_dir)
|
||||
path = os.path.join(artifact_uri, prefix, "train_dev", "estimator")
|
||||
|
||||
base_path = os.path.join(path, "base_weights.h5")
|
||||
weights_path = os.path.join(path, "weights.h5")
|
||||
|
||||
return base_path, weights_path
|
||||
|
||||
@lru_cache(maxsize=None)
|
||||
def __get_run(self, run_id):
|
||||
return mlflow.get_run(run_id)
|
||||
|
||||
def __get_classes(self, run_id, prefix="tt"):
|
||||
run = self.__get_run(run_id)
|
||||
|
||||
classes = json.loads(run.data.params[os.path.join(prefix, "train_dev/estimator/classes")].replace("'", '"'))
|
||||
|
||||
return classes
|
||||
|
||||
def __get_model_handle(self, run_id):
|
||||
run = self.__get_run(run_id)
|
||||
|
||||
model_handle_builder = load_object(run.data.params["model_handle_builder"].strip())
|
||||
|
||||
base_weights_path, weights_path = self.__get_weights_path(run_id)
|
||||
|
||||
model_handle = model_handle_builder(self.__get_classes(run_id), base_weights=base_weights_path)
|
||||
model_handle.load_top_weights(weights_path)
|
||||
|
||||
return model_handle
|
||||
|
||||
def __get_model(self, run_id) -> PredictionModelHandle:
|
||||
model_handle = self.__get_model_handle(run_id)
|
||||
model = PredictionModelHandle(model_handle)
|
||||
return model
|
||||
|
||||
def __getitem__(self, run_id):
|
||||
return {"model": self.__get_model(run_id), "classes": self.__get_classes(run_id)}
|
||||
|
||||
|
||||
def load_object(object_path):
|
||||
path_fragments = object_path.split(".")
|
||||
|
||||
module_path = ".".join(path_fragments[:-1])
|
||||
object_name = path_fragments[-1]
|
||||
|
||||
module = importlib.import_module(module_path)
|
||||
return getattr(module, object_name)
|
||||
16
image_prediction/redai_adapter/model.py
Normal file
16
image_prediction/redai_adapter/model.py
Normal file
@ -0,0 +1,16 @@
|
||||
from funcy import rcompose
|
||||
|
||||
|
||||
class PredictionModelHandle:
|
||||
"""Simplifies usage of ModelHandle instances for prediction purposes."""
|
||||
|
||||
def __init__(self, model_handle):
|
||||
self.__model_handle = model_handle
|
||||
|
||||
def predict(self, *args, **kwargs):
|
||||
predict = rcompose(self.__model_handle.prep_images, self.__model_handle.model.predict)
|
||||
return predict(*args, **kwargs)
|
||||
|
||||
def predict_proba(self, *args, **kwargs):
|
||||
predict = rcompose(self.__model_handle.prep_images, self.__model_handle.model.predict_proba)
|
||||
return predict(*args, **kwargs)
|
||||
@ -20,7 +20,8 @@ from image_prediction.image_extractor.extractors.mock import ImageExtractorMock
|
||||
from image_prediction.image_extractor.extractors.parsable import ParsablePDFImageExtractor
|
||||
from image_prediction.model_loader.database.connectors.mock import DatabaseConnectorMock
|
||||
from image_prediction.model_loader.loader import ModelLoader
|
||||
from image_prediction.model_loader.loaders.mlflow import MlflowConnector, MlflowModelReader
|
||||
from image_prediction.model_loader.loaders.mlflow import MlflowConnector
|
||||
from image_prediction.redai_adapter.mlflow import MlflowModelReader
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
import pytest
|
||||
|
||||
from image_prediction.model_loader.loaders.mlflow import PredictionModelHandle
|
||||
from image_prediction.redai_adapter.model import PredictionModelHandle
|
||||
|
||||
|
||||
@pytest.mark.parametrize("database_type", ["mock"])
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user