From 358d7ecd91ac3a394cd4f439b1e4da6569295426 Mon Sep 17 00:00:00 2001 From: Matthias Bisping Date: Tue, 29 Mar 2022 20:02:40 +0200 Subject: [PATCH] restructuring of modules --- .../model_loader/loaders/mlflow.py | 101 +----------------- image_prediction/redai_adapter/__init__.py | 0 image_prediction/redai_adapter/mlflow.py | 72 +++++++++++++ image_prediction/redai_adapter/model.py | 16 +++ test/unit_tests/conftest.py | 3 +- test/unit_tests/model_loader_test.py | 2 +- 6 files changed, 92 insertions(+), 102 deletions(-) create mode 100644 image_prediction/redai_adapter/__init__.py create mode 100644 image_prediction/redai_adapter/mlflow.py create mode 100644 image_prediction/redai_adapter/model.py diff --git a/image_prediction/model_loader/loaders/mlflow.py b/image_prediction/model_loader/loaders/mlflow.py index ada9899..9b6a044 100644 --- a/image_prediction/model_loader/loaders/mlflow.py +++ b/image_prediction/model_loader/loaders/mlflow.py @@ -1,104 +1,5 @@ -"""This module translates between the new ModelLoader API and the inconsistent and historically grown redai model and -MLflow API as well as the circumstance, that the model artifacts are currently not stored at a single place, due to the -need of loading the base weights of the pre-trained model, that became apparent at a later point than the design of the -MLflow storage and MlflowModelReader class; that is why the code in this module is so unclean. In the future, a -non-adhoc solution should be used that offers a clean API and storage solution. Either implement a well-designed MLflow -based solution or look into an alternative such as WandB or use a platform solution such as AWS. -""" - -import os - -os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3" - -import importlib -import json -import os -from functools import lru_cache - -from funcy import rcompose - from image_prediction.model_loader.database.connector import DatabaseConnector - -import mlflow - - -class PredictionModelHandle: - """Simplifies usage of ModelHandle instances for prediction purposes.""" - - def __init__(self, model_handle): - self.__model_handle = model_handle - - def predict(self, *args, **kwargs): - predict = rcompose(self.__model_handle.prep_images, self.__model_handle.model.predict) - return predict(*args, **kwargs) - - def predict_proba(self, *args, **kwargs): - predict = rcompose(self.__model_handle.prep_images, self.__model_handle.model.predict_proba) - return predict(*args, **kwargs) - - -class MlflowModelReader: - - def __init__(self, mlruns_dir=None): - self.mlruns_dir = mlruns_dir - mlflow.set_tracking_uri(self.mlruns_dir) - - @staticmethod - def __correct_artifact_uri(run_artifact_uri, base_path): - _, suffix = run_artifact_uri.split("mlruns/") - return os.path.join(base_path, suffix) - - def __get_weights_path(self, run_id, prefix="tt"): - run = self.__get_run(run_id) - - artifact_uri = self.__correct_artifact_uri(run.info.to_proto().artifact_uri, self.mlruns_dir) - path = os.path.join(artifact_uri, prefix, "train_dev", "estimator") - - base_path = os.path.join(path, "base_weights.h5") - weights_path = os.path.join(path, "weights.h5") - - return base_path, weights_path - - @lru_cache(maxsize=None) - def __get_run(self, run_id): - return mlflow.get_run(run_id) - - def __get_classes(self, run_id, prefix="tt"): - run = self.__get_run(run_id) - - classes = json.loads(run.data.params[os.path.join(prefix, "train_dev/estimator/classes")].replace("'", '"')) - - return classes - - def __get_model_handle(self, run_id): - run = self.__get_run(run_id) - - model_handle_builder = load_object(run.data.params["model_handle_builder"].strip()) - - base_weights_path, weights_path = self.__get_weights_path(run_id) - - model_handle = model_handle_builder(self.__get_classes(run_id), base_weights=base_weights_path) - model_handle.load_top_weights(weights_path) - - return model_handle - - def __get_model(self, run_id) -> PredictionModelHandle: - model_handle = self.__get_model_handle(run_id) - model = PredictionModelHandle(model_handle) - return model - - def __getitem__(self, run_id): - return {"model": self.__get_model(run_id), "classes": self.__get_classes(run_id)} - - -def load_object(object_path): - path_fragments = object_path.split(".") - - module_path = ".".join(path_fragments[:-1]) - object_name = path_fragments[-1] - - module = importlib.import_module(module_path) - return getattr(module, object_name) +from image_prediction.redai_adapter.mlflow import MlflowModelReader class MlflowConnector(DatabaseConnector): diff --git a/image_prediction/redai_adapter/__init__.py b/image_prediction/redai_adapter/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/image_prediction/redai_adapter/mlflow.py b/image_prediction/redai_adapter/mlflow.py new file mode 100644 index 0000000..23231e0 --- /dev/null +++ b/image_prediction/redai_adapter/mlflow.py @@ -0,0 +1,72 @@ +import importlib +import json +import os +from functools import lru_cache + +import mlflow + +from image_prediction.redai_adapter.model import PredictionModelHandle + + +class MlflowModelReader: + + def __init__(self, mlruns_dir=None): + self.mlruns_dir = mlruns_dir + mlflow.set_tracking_uri(self.mlruns_dir) + + @staticmethod + def __correct_artifact_uri(run_artifact_uri, base_path): + _, suffix = run_artifact_uri.split("mlruns/") + return os.path.join(base_path, suffix) + + def __get_weights_path(self, run_id, prefix="tt"): + run = self.__get_run(run_id) + + artifact_uri = self.__correct_artifact_uri(run.info.to_proto().artifact_uri, self.mlruns_dir) + path = os.path.join(artifact_uri, prefix, "train_dev", "estimator") + + base_path = os.path.join(path, "base_weights.h5") + weights_path = os.path.join(path, "weights.h5") + + return base_path, weights_path + + @lru_cache(maxsize=None) + def __get_run(self, run_id): + return mlflow.get_run(run_id) + + def __get_classes(self, run_id, prefix="tt"): + run = self.__get_run(run_id) + + classes = json.loads(run.data.params[os.path.join(prefix, "train_dev/estimator/classes")].replace("'", '"')) + + return classes + + def __get_model_handle(self, run_id): + run = self.__get_run(run_id) + + model_handle_builder = load_object(run.data.params["model_handle_builder"].strip()) + + base_weights_path, weights_path = self.__get_weights_path(run_id) + + model_handle = model_handle_builder(self.__get_classes(run_id), base_weights=base_weights_path) + model_handle.load_top_weights(weights_path) + + return model_handle + + def __get_model(self, run_id) -> PredictionModelHandle: + model_handle = self.__get_model_handle(run_id) + model = PredictionModelHandle(model_handle) + return model + + def __getitem__(self, run_id): + return {"model": self.__get_model(run_id), "classes": self.__get_classes(run_id)} + + +def load_object(object_path): + path_fragments = object_path.split(".") + + module_path = ".".join(path_fragments[:-1]) + object_name = path_fragments[-1] + + module = importlib.import_module(module_path) + return getattr(module, object_name) diff --git a/image_prediction/redai_adapter/model.py b/image_prediction/redai_adapter/model.py new file mode 100644 index 0000000..d646f24 --- /dev/null +++ b/image_prediction/redai_adapter/model.py @@ -0,0 +1,16 @@ +from funcy import rcompose + + +class PredictionModelHandle: + """Simplifies usage of ModelHandle instances for prediction purposes.""" + + def __init__(self, model_handle): + self.__model_handle = model_handle + + def predict(self, *args, **kwargs): + predict = rcompose(self.__model_handle.prep_images, self.__model_handle.model.predict) + return predict(*args, **kwargs) + + def predict_proba(self, *args, **kwargs): + predict = rcompose(self.__model_handle.prep_images, self.__model_handle.model.predict_proba) + return predict(*args, **kwargs) \ No newline at end of file diff --git a/test/unit_tests/conftest.py b/test/unit_tests/conftest.py index a81d7e0..cb8e3b2 100644 --- a/test/unit_tests/conftest.py +++ b/test/unit_tests/conftest.py @@ -20,7 +20,8 @@ from image_prediction.image_extractor.extractors.mock import ImageExtractorMock from image_prediction.image_extractor.extractors.parsable import ParsablePDFImageExtractor from image_prediction.model_loader.database.connectors.mock import DatabaseConnectorMock from image_prediction.model_loader.loader import ModelLoader -from image_prediction.model_loader.loaders.mlflow import MlflowConnector, MlflowModelReader +from image_prediction.model_loader.loaders.mlflow import MlflowConnector +from image_prediction.redai_adapter.mlflow import MlflowModelReader @pytest.fixture diff --git a/test/unit_tests/model_loader_test.py b/test/unit_tests/model_loader_test.py index 0216e36..d70f27b 100644 --- a/test/unit_tests/model_loader_test.py +++ b/test/unit_tests/model_loader_test.py @@ -1,6 +1,6 @@ import pytest -from image_prediction.model_loader.loaders.mlflow import PredictionModelHandle +from image_prediction.redai_adapter.model import PredictionModelHandle @pytest.mark.parametrize("database_type", ["mock"])