restructuring of modules
This commit is contained in:
parent
d33a882d65
commit
358d7ecd91
@ -1,104 +1,5 @@
|
|||||||
"""This module translates between the new ModelLoader API and the inconsistent and historically grown redai model and
|
|
||||||
MLflow API as well as the circumstance that the model artifacts are currently not stored in a single place, due to the
|
|
||||||
need of loading the base weights of the pre-trained model, that became apparent at a later point than the design of the
|
|
||||||
MLflow storage and MlflowModelReader class; that is why the code in this module is so unclean. In the future, a
|
|
||||||
non-ad-hoc solution should be used that offers a clean API and storage solution. Either implement a well-designed MLflow
|
|
||||||
based solution or look into an alternative such as WandB or use a platform solution such as AWS.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import os
|
|
||||||
|
|
||||||
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
|
|
||||||
|
|
||||||
import importlib
|
|
||||||
import json
|
|
||||||
import os
|
|
||||||
from functools import lru_cache
|
|
||||||
|
|
||||||
from funcy import rcompose
|
|
||||||
|
|
||||||
from image_prediction.model_loader.database.connector import DatabaseConnector
|
from image_prediction.model_loader.database.connector import DatabaseConnector
|
||||||
|
from image_prediction.redai_adapter.mlflow import MlflowModelReader
|
||||||
import mlflow
|
|
||||||
|
|
||||||
|
|
||||||
class PredictionModelHandle:
    """Prediction-oriented facade over a ModelHandle instance.

    Each public method first runs its arguments through the handle's ``prep_images`` and then hands the result to the
    matching method of the handle's ``model``.
    """

    def __init__(self, model_handle):
        # The handle must expose `prep_images` and a `model` offering `predict` / `predict_proba`.
        self.__model_handle = model_handle

    def __prepared(self, method_name):
        """Return a callable that applies `prep_images` and then the model method called `method_name`."""
        handle = self.__model_handle
        return rcompose(handle.prep_images, getattr(handle.model, method_name))

    def predict(self, *args, **kwargs):
        """Prepare the given inputs and return the result of ``model.predict`` on them."""
        pipeline = self.__prepared("predict")
        return pipeline(*args, **kwargs)

    def predict_proba(self, *args, **kwargs):
        """Prepare the given inputs and return the result of ``model.predict_proba`` on them."""
        pipeline = self.__prepared("predict_proba")
        return pipeline(*args, **kwargs)
|
|
||||||
|
|
||||||
|
|
||||||
class MlflowModelReader:
    """Builds ready-to-use prediction models from MLflow runs.

    Indexing an instance with a run id (``reader[run_id]``) returns a dict with the wrapped model under ``"model"``
    and the class list stored with the run under ``"classes"``.
    """

    def __init__(self, mlruns_dir=None):
        """Store `mlruns_dir` and pass it to ``mlflow.set_tracking_uri``.

        Args:
            mlruns_dir: Location of the MLflow run store. It is also the base path under which the weight files of a
                run are looked up.
        """
        self.mlruns_dir = mlruns_dir
        # Per-instance memo of fetched runs. An lru_cache on the method would key on `self` and keep every reader
        # instance alive for as long as the class-level cache exists.
        self.__run_cache = {}
        mlflow.set_tracking_uri(self.mlruns_dir)

    @staticmethod
    def __correct_artifact_uri(run_artifact_uri, base_path):
        """Re-root the part of `run_artifact_uri` that follows ``"mlruns/"`` under `base_path`.

        Raises:
            ValueError: If ``"mlruns/"`` does not occur exactly once in `run_artifact_uri`.
        """
        fragments = run_artifact_uri.split("mlruns/")
        if len(fragments) != 2:
            raise ValueError(
                f"Expected exactly one 'mlruns/' segment in artifact URI, got {run_artifact_uri!r}."
            )
        return os.path.join(base_path, fragments[1])

    def __get_weights_path(self, run_id, prefix="tt"):
        """Return the paths ``(base_weights.h5, weights.h5)`` of the run, located below `self.mlruns_dir`.

        Args:
            run_id: Id of the MLflow run.
            prefix: First directory below the run's artifact location.
        """
        run = self.__get_run(run_id)

        # The URI recorded with the run is not used as is; only its tail after "mlruns/" is kept.
        artifact_uri = self.__correct_artifact_uri(run.info.to_proto().artifact_uri, self.mlruns_dir)
        estimator_dir = os.path.join(artifact_uri, prefix, "train_dev", "estimator")

        return os.path.join(estimator_dir, "base_weights.h5"), os.path.join(estimator_dir, "weights.h5")

    def __get_run(self, run_id):
        """Return the MLflow run for `run_id`, fetching it at most once per reader instance."""
        try:
            return self.__run_cache[run_id]
        except KeyError:
            run = self.__run_cache[run_id] = mlflow.get_run(run_id)
            return run

    def __get_classes(self, run_id, prefix="tt"):
        """Return the classes stored in the run parameter ``<prefix>/train_dev/estimator/classes``.

        The parameter value is parsed as a Python literal. Values that are not valid literals are parsed the legacy
        way: single quotes are replaced by double quotes and the result is read as JSON.

        Raises:
            KeyError: If the run has no such parameter.
            json.JSONDecodeError: If the value can be parsed neither way.
        """
        import ast
        import posixpath

        run = self.__get_run(run_id)

        # Parameter keys are "/"-separated names, not file system paths, so join them platform-independently.
        raw_classes = run.data.params[posixpath.join(prefix, "train_dev/estimator/classes")]

        try:
            # Handles values the quote replacement corrupts, e.g. class names that contain an apostrophe.
            return ast.literal_eval(raw_classes)
        except (ValueError, SyntaxError):
            return json.loads(raw_classes.replace("'", '"'))

    def __get_model_handle(self, run_id):
        """Build the model handle named by the run's ``model_handle_builder`` parameter and load its weights."""
        run = self.__get_run(run_id)

        # The parameter holds a dotted path to the builder callable.
        model_handle_builder = load_object(run.data.params["model_handle_builder"].strip())

        base_weights_path, weights_path = self.__get_weights_path(run_id)

        model_handle = model_handle_builder(self.__get_classes(run_id), base_weights=base_weights_path)
        model_handle.load_top_weights(weights_path)

        return model_handle

    def __get_model(self, run_id) -> "PredictionModelHandle":
        """Return the run's model handle wrapped in a PredictionModelHandle."""
        return PredictionModelHandle(self.__get_model_handle(run_id))

    def __getitem__(self, run_id):
        """Return ``{"model": ..., "classes": ...}`` for the run with id `run_id`."""
        return {"model": self.__get_model(run_id), "classes": self.__get_classes(run_id)}
|
|
||||||
|
|
||||||
|
|
||||||
def load_object(object_path):
    """Import and return the object named by a dotted path such as ``"package.module.name"``.

    Everything before the last dot is imported as a module; the final fragment is looked up as an attribute of it.

    Args:
        object_path: Dotted path whose last fragment is the attribute to fetch.

    Returns:
        The attribute of the imported module.

    Raises:
        ValueError: If `object_path` has no module part (no dot, or nothing before the last dot).
        ImportError: If the module cannot be imported.
        AttributeError: If the module lacks the requested attribute.
    """
    module_path, _, object_name = object_path.rpartition(".")

    if not module_path:
        # import_module("") would raise an opaque "Empty module name"; report the offending path instead.
        raise ValueError(f"Expected a dotted path of the form 'module.object', got {object_path!r}.")

    module = importlib.import_module(module_path)
    return getattr(module, object_name)
|
|
||||||
|
|
||||||
|
|
||||||
class MlflowConnector(DatabaseConnector):
|
class MlflowConnector(DatabaseConnector):
|
||||||
|
|||||||
0
image_prediction/redai_adapter/__init__.py
Normal file
0
image_prediction/redai_adapter/__init__.py
Normal file
72
image_prediction/redai_adapter/mlflow.py
Normal file
72
image_prediction/redai_adapter/mlflow.py
Normal file
@ -0,0 +1,72 @@
|
|||||||
|
import importlib
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
from functools import lru_cache
|
||||||
|
|
||||||
|
import mlflow
|
||||||
|
|
||||||
|
from image_prediction.redai_adapter.model import PredictionModelHandle
|
||||||
|
|
||||||
|
|
||||||
|
class MlflowModelReader:
    """Builds ready-to-use prediction models from MLflow runs.

    Indexing an instance with a run id (``reader[run_id]``) returns a dict with the wrapped model under ``"model"``
    and the class list stored with the run under ``"classes"``.
    """

    def __init__(self, mlruns_dir=None):
        """Store `mlruns_dir` and pass it to ``mlflow.set_tracking_uri``.

        Args:
            mlruns_dir: Location of the MLflow run store. It is also the base path under which the weight files of a
                run are looked up.
        """
        self.mlruns_dir = mlruns_dir
        # Per-instance memo of fetched runs. An lru_cache on the method would key on `self` and keep every reader
        # instance alive for as long as the class-level cache exists.
        self.__run_cache = {}
        mlflow.set_tracking_uri(self.mlruns_dir)

    @staticmethod
    def __correct_artifact_uri(run_artifact_uri, base_path):
        """Re-root the part of `run_artifact_uri` that follows ``"mlruns/"`` under `base_path`.

        Raises:
            ValueError: If ``"mlruns/"`` does not occur exactly once in `run_artifact_uri`.
        """
        fragments = run_artifact_uri.split("mlruns/")
        if len(fragments) != 2:
            raise ValueError(
                f"Expected exactly one 'mlruns/' segment in artifact URI, got {run_artifact_uri!r}."
            )
        return os.path.join(base_path, fragments[1])

    def __get_weights_path(self, run_id, prefix="tt"):
        """Return the paths ``(base_weights.h5, weights.h5)`` of the run, located below `self.mlruns_dir`.

        Args:
            run_id: Id of the MLflow run.
            prefix: First directory below the run's artifact location.
        """
        run = self.__get_run(run_id)

        # The URI recorded with the run is not used as is; only its tail after "mlruns/" is kept.
        artifact_uri = self.__correct_artifact_uri(run.info.to_proto().artifact_uri, self.mlruns_dir)
        estimator_dir = os.path.join(artifact_uri, prefix, "train_dev", "estimator")

        return os.path.join(estimator_dir, "base_weights.h5"), os.path.join(estimator_dir, "weights.h5")

    def __get_run(self, run_id):
        """Return the MLflow run for `run_id`, fetching it at most once per reader instance."""
        try:
            return self.__run_cache[run_id]
        except KeyError:
            run = self.__run_cache[run_id] = mlflow.get_run(run_id)
            return run

    def __get_classes(self, run_id, prefix="tt"):
        """Return the classes stored in the run parameter ``<prefix>/train_dev/estimator/classes``.

        The parameter value is parsed as a Python literal. Values that are not valid literals are parsed the legacy
        way: single quotes are replaced by double quotes and the result is read as JSON.

        Raises:
            KeyError: If the run has no such parameter.
            json.JSONDecodeError: If the value can be parsed neither way.
        """
        import ast
        import posixpath

        run = self.__get_run(run_id)

        # Parameter keys are "/"-separated names, not file system paths, so join them platform-independently.
        raw_classes = run.data.params[posixpath.join(prefix, "train_dev/estimator/classes")]

        try:
            # Handles values the quote replacement corrupts, e.g. class names that contain an apostrophe.
            return ast.literal_eval(raw_classes)
        except (ValueError, SyntaxError):
            return json.loads(raw_classes.replace("'", '"'))

    def __get_model_handle(self, run_id):
        """Build the model handle named by the run's ``model_handle_builder`` parameter and load its weights."""
        run = self.__get_run(run_id)

        # The parameter holds a dotted path to the builder callable.
        model_handle_builder = load_object(run.data.params["model_handle_builder"].strip())

        base_weights_path, weights_path = self.__get_weights_path(run_id)

        model_handle = model_handle_builder(self.__get_classes(run_id), base_weights=base_weights_path)
        model_handle.load_top_weights(weights_path)

        return model_handle

    def __get_model(self, run_id) -> "PredictionModelHandle":
        """Return the run's model handle wrapped in a PredictionModelHandle."""
        return PredictionModelHandle(self.__get_model_handle(run_id))

    def __getitem__(self, run_id):
        """Return ``{"model": ..., "classes": ...}`` for the run with id `run_id`."""
        return {"model": self.__get_model(run_id), "classes": self.__get_classes(run_id)}
|
||||||
|
|
||||||
|
|
||||||
|
def load_object(object_path):
    """Import and return the object named by a dotted path such as ``"package.module.name"``.

    Everything before the last dot is imported as a module; the final fragment is looked up as an attribute of it.

    Args:
        object_path: Dotted path whose last fragment is the attribute to fetch.

    Returns:
        The attribute of the imported module.

    Raises:
        ValueError: If `object_path` has no module part (no dot, or nothing before the last dot).
        ImportError: If the module cannot be imported.
        AttributeError: If the module lacks the requested attribute.
    """
    module_path, _, object_name = object_path.rpartition(".")

    if not module_path:
        # import_module("") would raise an opaque "Empty module name"; report the offending path instead.
        raise ValueError(f"Expected a dotted path of the form 'module.object', got {object_path!r}.")

    module = importlib.import_module(module_path)
    return getattr(module, object_name)
|
||||||
16
image_prediction/redai_adapter/model.py
Normal file
16
image_prediction/redai_adapter/model.py
Normal file
@ -0,0 +1,16 @@
|
|||||||
|
from funcy import rcompose
|
||||||
|
|
||||||
|
|
||||||
|
class PredictionModelHandle:
    """Prediction-oriented facade over a ModelHandle instance.

    Each public method first runs its arguments through the handle's ``prep_images`` and then hands the result to the
    matching method of the handle's ``model``.
    """

    def __init__(self, model_handle):
        # The handle must expose `prep_images` and a `model` offering `predict` / `predict_proba`.
        self.__model_handle = model_handle

    def __prepared(self, method_name):
        """Return a callable that applies `prep_images` and then the model method called `method_name`."""
        handle = self.__model_handle
        return rcompose(handle.prep_images, getattr(handle.model, method_name))

    def predict(self, *args, **kwargs):
        """Prepare the given inputs and return the result of ``model.predict`` on them."""
        pipeline = self.__prepared("predict")
        return pipeline(*args, **kwargs)

    def predict_proba(self, *args, **kwargs):
        """Prepare the given inputs and return the result of ``model.predict_proba`` on them."""
        pipeline = self.__prepared("predict_proba")
        return pipeline(*args, **kwargs)
|
||||||
@ -20,7 +20,8 @@ from image_prediction.image_extractor.extractors.mock import ImageExtractorMock
|
|||||||
from image_prediction.image_extractor.extractors.parsable import ParsablePDFImageExtractor
|
from image_prediction.image_extractor.extractors.parsable import ParsablePDFImageExtractor
|
||||||
from image_prediction.model_loader.database.connectors.mock import DatabaseConnectorMock
|
from image_prediction.model_loader.database.connectors.mock import DatabaseConnectorMock
|
||||||
from image_prediction.model_loader.loader import ModelLoader
|
from image_prediction.model_loader.loader import ModelLoader
|
||||||
from image_prediction.model_loader.loaders.mlflow import MlflowConnector, MlflowModelReader
|
from image_prediction.model_loader.loaders.mlflow import MlflowConnector
|
||||||
|
from image_prediction.redai_adapter.mlflow import MlflowModelReader
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
|
|||||||
@ -1,6 +1,6 @@
|
|||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from image_prediction.model_loader.loaders.mlflow import PredictionModelHandle
|
from image_prediction.redai_adapter.model import PredictionModelHandle
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("database_type", ["mock"])
|
@pytest.mark.parametrize("database_type", ["mock"])
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user