restructuring of modules

This commit is contained in:
Matthias Bisping 2022-03-29 20:02:40 +02:00
parent d33a882d65
commit 358d7ecd91
6 changed files with 92 additions and 102 deletions

View File

@ -1,104 +1,5 @@
"""This module translates between the new ModelLoader API and the inconsistent and historically grown redai model and
MLflow API. It also has to cope with the fact that the model artifacts are currently not stored in a single place: the
need to load the base weights of the pre-trained model only became apparent after the MLflow storage and the
MlflowModelReader class had been designed. That is why the code in this module is so unclean. In the future, a
non-adhoc solution should be used that offers a clean API and storage solution: either implement a well-designed
MLflow-based solution, look into an alternative such as WandB, or use a platform solution such as AWS.
"""
import os
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
import importlib
import json
import os
from functools import lru_cache
from funcy import rcompose
from image_prediction.model_loader.database.connector import DatabaseConnector from image_prediction.model_loader.database.connector import DatabaseConnector
from image_prediction.redai_adapter.mlflow import MlflowModelReader
import mlflow
class PredictionModelHandle:
    """Prediction-only facade around a ModelHandle instance.

    Each call first runs the wrapped handle's ``prep_images`` on the given
    arguments and then passes the result to the wrapped model.
    """

    def __init__(self, model_handle):
        """Store the ModelHandle whose preprocessing and model are used."""
        self.__model_handle = model_handle

    def predict(self, *args, **kwargs):
        """Return ``model.predict`` applied to ``prep_images(*args, **kwargs)``."""
        handle = self.__model_handle
        # rcompose chains left to right: preprocessing first, then the model.
        pipeline = rcompose(handle.prep_images, handle.model.predict)
        return pipeline(*args, **kwargs)

    def predict_proba(self, *args, **kwargs):
        """Return ``model.predict_proba`` applied to ``prep_images(*args, **kwargs)``."""
        handle = self.__model_handle
        # rcompose chains left to right: preprocessing first, then the model.
        pipeline = rcompose(handle.prep_images, handle.model.predict_proba)
        return pipeline(*args, **kwargs)
class MlflowModelReader:
    """Load prediction models and their class lists from MLflow runs.

    Indexing an instance with a run id (``reader[run_id]``) returns a dict with
    the keys ``"model"`` (a PredictionModelHandle) and ``"classes"``.
    """

    def __init__(self, mlruns_dir=None):
        """Set the MLflow tracking URI to ``mlruns_dir`` and remember it.

        NOTE(review): with the default ``mlruns_dir=None`` the weight path
        lookup fails inside ``os.path.join`` -- confirm whether ``None`` is
        meant to be a supported value.
        """
        self.mlruns_dir = mlruns_dir
        # Runs are cached per instance. Decorating the method with lru_cache
        # would key a process-wide cache on ``self`` and keep every reader
        # alive for the lifetime of that cache.
        self.__run_cache = {}
        mlflow.set_tracking_uri(self.mlruns_dir)

    @staticmethod
    def __correct_artifact_uri(run_artifact_uri, base_path):
        """Re-root the part of ``run_artifact_uri`` after ``mlruns/`` onto ``base_path``.

        When ``mlruns/`` occurs more than once, the last occurrence is used.

        Raises:
            ValueError: if ``run_artifact_uri`` does not contain ``mlruns/``.
        """
        _, marker, suffix = run_artifact_uri.rpartition("mlruns/")
        if not marker:
            raise ValueError(f"Artifact URI {run_artifact_uri!r} does not contain 'mlruns/'.")
        return os.path.join(base_path, suffix)

    def __get_weights_path(self, run_id, prefix="tt"):
        """Return ``(base_weights_path, weights_path)`` for the given run."""
        run = self.__get_run(run_id)
        artifact_uri = self.__correct_artifact_uri(run.info.to_proto().artifact_uri, self.mlruns_dir)
        path = os.path.join(artifact_uri, prefix, "train_dev", "estimator")
        base_path = os.path.join(path, "base_weights.h5")
        weights_path = os.path.join(path, "weights.h5")
        return base_path, weights_path

    def __get_run(self, run_id):
        """Return the MLflow run for ``run_id``, fetching it at most once per instance."""
        try:
            return self.__run_cache[run_id]
        except KeyError:
            run = self.__run_cache[run_id] = mlflow.get_run(run_id)
            return run

    def __get_classes(self, run_id, prefix="tt"):
        """Return the class list stored in the run's parameters."""
        # The key is a slash-separated parameter name, not a filesystem path,
        # so it must not follow the platform's path separator.
        import posixpath

        run = self.__get_run(run_id)
        key = posixpath.join(prefix, "train_dev/estimator/classes")
        # The parameter is stored with single quotes; swap them so it parses as JSON.
        return json.loads(run.data.params[key].replace("'", '"'))

    def __get_model_handle(self, run_id):
        """Build the model handle named by the run and load its weights.

        The run parameter ``model_handle_builder`` holds a dotted path that is
        resolved with ``load_object``. The builder is called with the class
        list and the base weights path; the top weights are loaded afterwards.
        """
        run = self.__get_run(run_id)
        model_handle_builder = load_object(run.data.params["model_handle_builder"].strip())
        base_weights_path, weights_path = self.__get_weights_path(run_id)
        model_handle = model_handle_builder(self.__get_classes(run_id), base_weights=base_weights_path)
        model_handle.load_top_weights(weights_path)
        return model_handle

    def __get_model(self, run_id) -> "PredictionModelHandle":
        """Return the run's model handle wrapped in a PredictionModelHandle."""
        return PredictionModelHandle(self.__get_model_handle(run_id))

    def __getitem__(self, run_id):
        """Return ``{"model": ..., "classes": ...}`` for ``run_id``."""
        return {"model": self.__get_model(run_id), "classes": self.__get_classes(run_id)}
def load_object(object_path):
    """Import and return the object named by a dotted path.

    Everything before the last dot is imported as a module; the part after it
    is looked up as an attribute of that module.
    """
    module_path, _, object_name = object_path.rpartition(".")
    return getattr(importlib.import_module(module_path), object_name)
class MlflowConnector(DatabaseConnector): class MlflowConnector(DatabaseConnector):

View File

@ -0,0 +1,72 @@
import importlib
import json
import os
from functools import lru_cache
import mlflow
from image_prediction.redai_adapter.model import PredictionModelHandle
class MlflowModelReader:
    """Load prediction models and their class lists from MLflow runs.

    Indexing an instance with a run id (``reader[run_id]``) returns a dict with
    the keys ``"model"`` (a PredictionModelHandle) and ``"classes"``.
    """

    def __init__(self, mlruns_dir=None):
        """Set the MLflow tracking URI to ``mlruns_dir`` and remember it.

        NOTE(review): with the default ``mlruns_dir=None`` the weight path
        lookup fails inside ``os.path.join`` -- confirm whether ``None`` is
        meant to be a supported value.
        """
        self.mlruns_dir = mlruns_dir
        # Runs are cached per instance. Decorating the method with lru_cache
        # would key a process-wide cache on ``self`` and keep every reader
        # alive for the lifetime of that cache.
        self.__run_cache = {}
        mlflow.set_tracking_uri(self.mlruns_dir)

    @staticmethod
    def __correct_artifact_uri(run_artifact_uri, base_path):
        """Re-root the part of ``run_artifact_uri`` after ``mlruns/`` onto ``base_path``.

        When ``mlruns/`` occurs more than once, the last occurrence is used.

        Raises:
            ValueError: if ``run_artifact_uri`` does not contain ``mlruns/``.
        """
        _, marker, suffix = run_artifact_uri.rpartition("mlruns/")
        if not marker:
            raise ValueError(f"Artifact URI {run_artifact_uri!r} does not contain 'mlruns/'.")
        return os.path.join(base_path, suffix)

    def __get_weights_path(self, run_id, prefix="tt"):
        """Return ``(base_weights_path, weights_path)`` for the given run."""
        run = self.__get_run(run_id)
        artifact_uri = self.__correct_artifact_uri(run.info.to_proto().artifact_uri, self.mlruns_dir)
        path = os.path.join(artifact_uri, prefix, "train_dev", "estimator")
        base_path = os.path.join(path, "base_weights.h5")
        weights_path = os.path.join(path, "weights.h5")
        return base_path, weights_path

    def __get_run(self, run_id):
        """Return the MLflow run for ``run_id``, fetching it at most once per instance."""
        try:
            return self.__run_cache[run_id]
        except KeyError:
            run = self.__run_cache[run_id] = mlflow.get_run(run_id)
            return run

    def __get_classes(self, run_id, prefix="tt"):
        """Return the class list stored in the run's parameters."""
        # The key is a slash-separated parameter name, not a filesystem path,
        # so it must not follow the platform's path separator.
        import posixpath

        run = self.__get_run(run_id)
        key = posixpath.join(prefix, "train_dev/estimator/classes")
        # The parameter is stored with single quotes; swap them so it parses as JSON.
        return json.loads(run.data.params[key].replace("'", '"'))

    def __get_model_handle(self, run_id):
        """Build the model handle named by the run and load its weights.

        The run parameter ``model_handle_builder`` holds a dotted path that is
        resolved with ``load_object``. The builder is called with the class
        list and the base weights path; the top weights are loaded afterwards.
        """
        run = self.__get_run(run_id)
        model_handle_builder = load_object(run.data.params["model_handle_builder"].strip())
        base_weights_path, weights_path = self.__get_weights_path(run_id)
        model_handle = model_handle_builder(self.__get_classes(run_id), base_weights=base_weights_path)
        model_handle.load_top_weights(weights_path)
        return model_handle

    def __get_model(self, run_id) -> "PredictionModelHandle":
        """Return the run's model handle wrapped in a PredictionModelHandle."""
        return PredictionModelHandle(self.__get_model_handle(run_id))

    def __getitem__(self, run_id):
        """Return ``{"model": ..., "classes": ...}`` for ``run_id``."""
        return {"model": self.__get_model(run_id), "classes": self.__get_classes(run_id)}
def load_object(object_path):
    """Import and return the object named by a dotted path.

    Everything before the last dot is imported as a module; the part after it
    is looked up as an attribute of that module.
    """
    module_path, _, object_name = object_path.rpartition(".")
    return getattr(importlib.import_module(module_path), object_name)

View File

@ -0,0 +1,16 @@
from funcy import rcompose
class PredictionModelHandle:
    """Prediction-only facade around a ModelHandle instance.

    Each call first runs the wrapped handle's ``prep_images`` on the given
    arguments and then passes the result to the wrapped model.
    """

    def __init__(self, model_handle):
        """Store the ModelHandle whose preprocessing and model are used."""
        self.__model_handle = model_handle

    def predict(self, *args, **kwargs):
        """Return ``model.predict`` applied to ``prep_images(*args, **kwargs)``."""
        handle = self.__model_handle
        # rcompose chains left to right: preprocessing first, then the model.
        pipeline = rcompose(handle.prep_images, handle.model.predict)
        return pipeline(*args, **kwargs)

    def predict_proba(self, *args, **kwargs):
        """Return ``model.predict_proba`` applied to ``prep_images(*args, **kwargs)``."""
        handle = self.__model_handle
        # rcompose chains left to right: preprocessing first, then the model.
        pipeline = rcompose(handle.prep_images, handle.model.predict_proba)
        return pipeline(*args, **kwargs)

View File

@ -20,7 +20,8 @@ from image_prediction.image_extractor.extractors.mock import ImageExtractorMock
from image_prediction.image_extractor.extractors.parsable import ParsablePDFImageExtractor from image_prediction.image_extractor.extractors.parsable import ParsablePDFImageExtractor
from image_prediction.model_loader.database.connectors.mock import DatabaseConnectorMock from image_prediction.model_loader.database.connectors.mock import DatabaseConnectorMock
from image_prediction.model_loader.loader import ModelLoader from image_prediction.model_loader.loader import ModelLoader
from image_prediction.model_loader.loaders.mlflow import MlflowConnector, MlflowModelReader from image_prediction.model_loader.loaders.mlflow import MlflowConnector
from image_prediction.redai_adapter.mlflow import MlflowModelReader
@pytest.fixture @pytest.fixture

View File

@ -1,6 +1,6 @@
import pytest import pytest
from image_prediction.model_loader.loaders.mlflow import PredictionModelHandle from image_prediction.redai_adapter.model import PredictionModelHandle
@pytest.mark.parametrize("database_type", ["mock"]) @pytest.mark.parametrize("database_type", ["mock"])