diff --git a/image_prediction/exceptions.py b/image_prediction/exceptions.py index fe0136c..9cc0f5d 100644 --- a/image_prediction/exceptions.py +++ b/image_prediction/exceptions.py @@ -10,5 +10,9 @@ class UnknownModelLoader(ValueError): pass +class UnknownDatabaseType(ValueError): + pass + + class IncorrectInstantiation(RuntimeError): pass diff --git a/image_prediction/model_loader/database/connectors/mock.py b/image_prediction/model_loader/database/connectors/mock.py new file mode 100644 index 0000000..6bf1199 --- /dev/null +++ b/image_prediction/model_loader/database/connectors/mock.py @@ -0,0 +1,10 @@ +from image_prediction.model_loader.database.connector import DatabaseConnector + + +class DatabaseConnectorMock(DatabaseConnector): + + def __init__(self, store: dict): + self.store = store + + def get_object(self, identifier): + return self.store[identifier] diff --git a/image_prediction/model_loader/loader.py b/image_prediction/model_loader/loader.py index 32e76d6..41cc5c2 100644 --- a/image_prediction/model_loader/loader.py +++ b/image_prediction/model_loader/loader.py @@ -1,12 +1,19 @@ -import abc +from functools import lru_cache + +from image_prediction.model_loader.database.connector import DatabaseConnector -class ModelLoader(abc.ABC): +class ModelLoader: - @abc.abstractmethod - def load_model(self, *args, **kwargs): - pass + def __init__(self, database_connector: DatabaseConnector): + self.database_connector = database_connector - @abc.abstractmethod - def load_classes(self, *args, **kwargs): - pass + @lru_cache(maxsize=None) + def __get_object(self, identifier): + return self.database_connector.get_object(identifier) + + def load_model(self, identifier): + return self.__get_object(identifier)["model"] + + def load_classes(self, identifier): + return self.__get_object(identifier)["classes"] diff --git a/image_prediction/model_loader/loaders/mlflow.py b/image_prediction/model_loader/loaders/mlflow.py index 3328b5d..d0b98b9 100644 --- a/image_prediction/model_loader/loaders/mlflow.py +++ b/image_prediction/model_loader/loaders/mlflow.py @@ -1,101 +1,123 @@ -"""This module translates between the new ModelLoader API and the inconsistent and historically grown redai model and -MLflow API as well as the circumstance, that the model artifacts are currently not stored at a single place, due to the -need of loading the base weights of the pre-trained model, that became apparent at a later point than the design of the -MLflow storage and MlflowModelReader class; that is why the code in this module is so unclean. In the future, a -non-adhoc solution should be used that offers a clean API and storage solution. Either implement a well-designed MLflow -based solution or look into an alternative such as WandB or use a platform solution such as AWS. -""" -import importlib -import json -import os -import warnings - -import numpy as np - -from image_prediction.exceptions import IncorrectInstantiation -from image_prediction.model_loader.loader import ModelLoader - -warnings.filterwarnings("ignore", category=DeprecationWarning, module="pkg_resources") - -import mlflow - - -def load_object(object_path): - path_fragments = object_path.split(".") - - module_path = ".".join(path_fragments[:-1]) - object_name = path_fragments[-1] - - module = importlib.import_module(module_path) - return getattr(module, object_name) - - -def to_local_path(uri): - return uri[7:] - - -class MlflowModelReader: - - def __init__(self, run_id, mlruns_dir=None): - mlflow.set_tracking_uri(mlruns_dir) - - self.run_id = run_id - self.run = mlflow.get_run(run_id) - self.artifact_uri = self.__correct_artifact_uri(self.run.info.to_proto().artifact_uri, mlruns_dir) - - @staticmethod - def __correct_artifact_uri(run_artifact_uri, base_path): - _, suffix = run_artifact_uri.split("mlruns/") - return os.path.join(base_path, suffix) - - def get_weights_path(self, prefix="tt"): - path = os.path.join(self.artifact_uri, prefix, "train_dev", "estimator", "weights.h5") - return path - - def get_classes(self, prefix="tt"): - classes = json.loads( - self.run.data.params[os.path.join(prefix, "train_dev/estimator/classes")].replace("'", '"') - ) - return classes - - def get_model_handle(self, base_weights=None): - weights_path = self.get_weights_path() - model_handle_builder = load_object(self.run.data.params["model_handle_builder"].strip()) - model_handle = model_handle_builder(self.get_classes(), base_weights=base_weights) - model_handle.load_top_weights(weights_path) - return model_handle - - -class MlflowLoader(ModelLoader): - - def __init__(self, mlruns_dir): - self.__mlruns_dir = mlruns_dir - self._model_handle = None - self.__last_run_id = None - self._base_weights = None - - def load_model(self, run_id, base_weights=None): - - if not base_weights: - - if not self._base_weights: - raise IncorrectInstantiation("MlflowReader needs to be initialized via get_model_loader.") - - base_weights = self._base_weights - - if not self._model_handle and run_id == self.__last_run_id: - mlflow_reader = MlflowModelReader(run_id, mlruns_dir=self.__mlruns_dir) - model_handel = mlflow_reader.get_model_handle(base_weights) - self._model_handle = model_handel - self.__last_run_id = run_id - - return self._model_handle - - def load_classes(self, run_id): - model_handle = self.load_model(run_id) - - classes = model_handle.model.classes_ - classes_readable = np.array(model_handle.classes) - classes_readable_aligned = classes_readable[classes[list(range(len(classes)))]] - - return classes_readable_aligned +# """This module translates between the new ModelLoader API and the inconsistent and historically grown redai model and +# MLflow API as well as the circumstance, that the model artifacts are currently not stored at a single place, due to the +# need of loading the base weights of the pre-trained model, that became apparent at a later point than the design of the +# MLflow storage and MlflowModelReader class; that is why the code in this module is so unclean. In the future, a +# non-adhoc solution should be used that offers a clean API and storage solution. Either implement a well-designed MLflow +# based solution or look into an alternative such as WandB or use a platform solution such as AWS. +# """ +# import importlib +# import json +# import os +# import warnings +# from typing import Mapping +# +# import numpy as np +# from funcy import rcompose +# +# from image_prediction.exceptions import IncorrectInstantiation +# from image_prediction.model_loader.loader import ModelLoader +# +# warnings.filterwarnings("ignore", category=DeprecationWarning, module="pkg_resources") +# +# import mlflow +# +# +# def load_object(object_path): +# path_fragments = object_path.split(".") +# +# module_path = ".".join(path_fragments[:-1]) +# object_name = path_fragments[-1] +# +# module = importlib.import_module(module_path) +# return getattr(module, object_name) +# +# +# def to_local_path(uri): +# return uri[7:] +# +# +# class MlflowModelReader: +# +# def __init__(self, run_id, mlruns_dir=None): +# mlflow.set_tracking_uri(mlruns_dir) +# +# self.run_id = run_id +# self.run = mlflow.get_run(run_id) +# self.artifact_uri = self.__correct_artifact_uri(self.run.info.to_proto().artifact_uri, mlruns_dir) +# +# @staticmethod +# def __correct_artifact_uri(run_artifact_uri, base_path): +# _, suffix = run_artifact_uri.split("mlruns/") +# return os.path.join(base_path, suffix) +# +# def get_weights_path(self, prefix="tt"): +# path = os.path.join(self.artifact_uri, prefix, "train_dev", "estimator", "weights.h5") +# return path +# +# def get_classes(self, prefix="tt"): +# classes = json.loads( +# self.run.data.params[os.path.join(prefix, "train_dev/estimator/classes")].replace("'", '"') +# ) +# return classes +# +# def get_model_handle(self, base_weights=None): +# weights_path = self.get_weights_path() +# model_handle_builder = load_object(self.run.data.params["model_handle_builder"].strip()) +# model_handle = model_handle_builder(self.get_classes(), base_weights=base_weights) +# model_handle.load_top_weights(weights_path) +# return model_handle +# +# +# class PredictionModelHandle: +# """Simplifies usage of ModelHandle instances for prediction purposes.""" +# +# def __init__(self, model_handle, classes_readable: Mapping[int, str]): +# self.__model_handle = model_handle +# self.__classes_readable = classes_readable +# +# @property +# def classes(self): +# return self.__classes_readable +# +# def predict(self, *args, **kwargs): +# predict = rcompose(self.__model_handle.prep_images, self.__model_handle.model.predict) +# return predict(*args, **kwargs) +# +# def predict_proba(self, *args, **kwargs): +# predict = rcompose(self.__model_handle.prep_images, self.__model_handle.model.predict_proba) +# return predict(*args, **kwargs) +# +# +# class MlflowLoader(ModelLoader): +# +# def __init__(self, mlruns_dir): +# self.__mlruns_dir = mlruns_dir +# self._base_weights = None +# +# def load_model(self, run_id, base_weights=None) -> PredictionModelHandle: +# +# # TODO: refac https://stackoverflow.com/questions/42735421/how-to-restrict-object-instantiation-only-via-a-factory-in-python +# if not base_weights: +# +# if not self._base_weights: +# raise IncorrectInstantiation("MlflowReader needs to be initialized via get_model_loader.") +# +# base_weights = self._base_weights +# +# mlflow_reader = MlflowModelReader(run_id, mlruns_dir=self.__mlruns_dir) +# model_handel = mlflow_reader.get_model_handle(base_weights) +# model_handle = model_handel +# classes_readable = self.__load_classes(model_handle) +# +# model = PredictionModelHandle(model_handle, classes_readable) +# +# return model +# +# @staticmethod +# def __load_classes(model_handle): +# +# classes = model_handle.model.classes_ +# classes_readable = np.array(model_handle.classes) +# classes_readable_aligned = classes_readable[classes[list(range(len(classes)))]] +# +# return classes_readable_aligned diff --git a/image_prediction/model_loader/loaders/mock.py b/image_prediction/model_loader/loaders/mock.py index be9ff82..6269261 100644 --- a/image_prediction/model_loader/loaders/mock.py +++ b/image_prediction/model_loader/loaders/mock.py @@ -9,7 +9,3 @@ class ModelLoaderMock(ModelLoader): def load_model(self, identifier): assert self.model is not None, "Set the model to be returned first via monkeypatching" return self.model - - def load_classes(self, identifier): - assert self.classes is not None, "Set the classes to be returned first via monkeypatching" - return self.classes diff --git a/image_prediction/model_loading.py b/image_prediction/model_loading.py deleted file mode 100644 index a3f0a32..0000000 --- a/image_prediction/model_loading.py +++ /dev/null @@ -1,17 +0,0 @@ -from collections import namedtuple - -from image_prediction.locations import MLRUNS_DIR -from image_prediction.model_loader.loader import ModelLoader -from image_prediction.model_loader.loaders.mlflow import MlflowLoader - -ModelClassesPair = namedtuple("ModelClassesPair", ["model", "classes"]) - - -def load_model_and_classes(identifier, model_loader: ModelLoader = None) -> ModelClassesPair: - if not model_loader: - model_loader = MlflowLoader(MLRUNS_DIR) - - model = model_loader.load_model(identifier) - classes = model_loader.load_classes(identifier) - - return ModelClassesPair(model, classes) diff --git a/test/unit_tests/conftest.py b/test/unit_tests/conftest.py index a281cc7..b420753 100644 --- a/test/unit_tests/conftest.py +++ b/test/unit_tests/conftest.py @@ -1,4 +1,5 @@ import random +import string import tempfile from itertools import starmap from operator import itemgetter @@ -13,13 +14,12 @@ from image_prediction.classifier.image_classifier import ImageClassifier from image_prediction.estimator.adapter.adapters.keras import KerasEstimatorAdapter from image_prediction.estimator.adapter.adapters.mock import EstimatorMock, EstimatorAdapterMock from image_prediction.estimator.preprocessor.preprocessors.basic import BasicPreprocessor -from image_prediction.exceptions import UnknownEstimatorAdapter, UnknownImageExtractor, UnknownModelLoader +from image_prediction.exceptions import UnknownEstimatorAdapter, UnknownImageExtractor, UnknownDatabaseType from image_prediction.image_extractor.extractor import ImageMetadataPair from image_prediction.image_extractor.extractors.mock import ImageExtractorMock from image_prediction.image_extractor.extractors.parsable import ParsablePDFImageExtractor -from image_prediction.model_loader.loaders.loaders import get_mlflow_loader -from image_prediction.model_loader.loaders.mlflow import MlflowLoader -from image_prediction.model_loader.loaders.mock import ModelLoaderMock +from image_prediction.model_loader.database.connectors.mock import DatabaseConnectorMock +from image_prediction.model_loader.loader import ModelLoader @pytest.fixture @@ -207,29 +207,79 @@ def add_image_to_last_page(pdf: fpdf.fpdf.FPDF, image_metadata_pair): pdf.image(temp_image.name, x=x, y=y, w=w, h=h, type="png") -@pytest.fixture -def model_handle_mock(classes, classifier): - - class ModelHandleMock: - - def __init__(self, classes): - classifier.classes_ = np.array(list(range(len(classes)))) - self.classes = classes - self.model = classifier - - return ModelHandleMock(classes) +# @pytest.fixture +# def model_handle_mock(classes, classifier): +# +# class ModelHandleMock: +# +# def __init__(self, classes): +# classifier.classes_ = np.array(list(range(len(classes)))) +# self.classes = classes +# self.model = classifier +# +# return ModelHandleMock(classes) +# +# +# @pytest.fixture +# def prediction_model_handle_mock(model_handle_mock, classes): +# return PredictionModelHandle(model_handle_mock, classes) @pytest.fixture -def model_loader(loader_type, monkeypatch, model_handle_mock, classes): - if loader_type == "mock": - loader = ModelLoaderMock() - monkeypatch.setattr(loader, "model", model_handle_mock) - monkeypatch.setattr(loader, "classes", classes) - elif loader_type == "mlflow": - loader = get_mlflow_loader() - monkeypatch.setattr(loader, "_model_handle", model_handle_mock) +def model(): + + class Model: + + @staticmethod + def predict(*args): + return True + + @staticmethod + def predict_proba(*args): + return True + + return Model() + + +@pytest.fixture +def model_database_record_identifier(): + return "".join(random.sample(string.ascii_letters, k=10)) + + +@pytest.fixture +def model_database_record(model, classes): + return {"model": model, "classes": classes} + + +@pytest.fixture +def model_database(model_database_record, model_database_record_identifier): + return {model_database_record_identifier: model_database_record} + + +@pytest.fixture +def database_connector(database_type, model_database): + if database_type == "mock": + return DatabaseConnectorMock(model_database) else: - raise UnknownModelLoader(f"No model loader for type {loader_type} was specified.") + raise UnknownDatabaseType(f"No connector for database type {database_type} was specified.") - return loader + +@pytest.fixture +def model_loader(database_connector): + return ModelLoader(database_connector) + + +# @pytest.fixture +# def model_loader(loader_type, monkeypatch, model_handle_mock, classes): +# if loader_type == "mock": +# loader = ModelLoaderMock() +# monkeypatch.setattr(loader, "model", model_handle_mock) +# +# # elif loader_type == "mlflow": +# # loader = get_mlflow_loader() +# # monkeypatch.setattr(loader, "_model_handle", model_handle_mock) +# +# else: +# raise UnknownModelLoader(f"No model loader for type {loader_type} was specified.") +# +# return loader diff --git a/test/unit_tests/model_loader_test.py b/test/unit_tests/model_loader_test.py index 8b4a64a..ecba32b 100644 --- a/test/unit_tests/model_loader_test.py +++ b/test/unit_tests/model_loader_test.py @@ -1,15 +1,19 @@ import numpy as np import pytest -from image_prediction.model_loading import load_model_and_classes +# @pytest.mark.parametrize("loader_type", ["mock"]) +# @pytest.mark.parametrize("estimator_type", ["mock"]) +# @pytest.mark.parametrize("batch_size", [3]) +# def test_load_model_and_classes(model_loader, model_handle_mock, classes): +# model_loaded, classes_loaded = model_loader.load_model_and_classes("an identifier") +# assert model_loaded == model_handle_mock +# assert np.all(classes_loaded == classes) -@pytest.mark.parametrize("loader_type", ["mock", "mlflow"]) -@pytest.mark.parametrize("estimator_type", ["mock"]) -@pytest.mark.parametrize("batch_size", [3]) -def test_load_model_and_classes(model_loader, model_handle_mock, classes): - # Load twice to test caching logic - for _ in range(2): - model_loaded, classes_loaded = load_model_and_classes("some random identifier", model_loader=model_loader) - assert model_loaded == model_handle_mock - assert np.all(classes_loaded == classes) +@pytest.mark.parametrize("database_type", ["mock"]) +def test_load_model_and_classes(model_loader, model_database_record_identifier, model, classes): + model_loaded = model_loader.load_model(model_database_record_identifier) + classes_loaded = model_loader.load_classes(model_database_record_identifier) + + assert model_loaded == model + assert classes_loaded == classes