diff --git a/image_prediction/exceptions.py b/image_prediction/exceptions.py index 3098a10..6982dff 100644 --- a/image_prediction/exceptions.py +++ b/image_prediction/exceptions.py @@ -4,3 +4,7 @@ class UnknownEstimatorAdapter(ValueError): class UnknownImageExtractor(ValueError): pass + + +class UnknownModelLoader(ValueError): + pass diff --git a/image_prediction/model_loader/__init__.py b/image_prediction/model_loader/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/image_prediction/model_loader/loader.py b/image_prediction/model_loader/loader.py new file mode 100644 index 0000000..b78543e --- /dev/null +++ b/image_prediction/model_loader/loader.py @@ -0,0 +1,12 @@ +import abc + + +class ModelLoader(abc.ABC): + + @abc.abstractmethod + def load_model(self, identifier): + pass + + @abc.abstractmethod + def load_classes(self, identifier): + pass diff --git a/image_prediction/model_loader/loaders/__init__.py b/image_prediction/model_loader/loaders/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/image_prediction/model_loader/loaders/mlflow.py b/image_prediction/model_loader/loaders/mlflow.py new file mode 100644 index 0000000..8495e30 --- /dev/null +++ b/image_prediction/model_loader/loaders/mlflow.py @@ -0,0 +1,83 @@ +import importlib +import json +import os +import warnings + +import numpy as np + +from image_prediction.locations import BASE_WEIGHTS +from image_prediction.model_loader.loader import ModelLoader + +warnings.filterwarnings("ignore", category=DeprecationWarning, module="pkg_resources") + +import mlflow + + +def load_object(object_path): + path_fragments = object_path.split(".") + + module_path = ".".join(path_fragments[:-1]) + object_name = path_fragments[-1] + + module = importlib.import_module(module_path) + return getattr(module, object_name) + + +def to_local_path(uri): + return uri[7:] + + +class MlflowModelReader: + + def __init__(self, run_id, mlruns_dir=None): + mlflow.set_tracking_uri(mlruns_dir) + + self.run_id = run_id + self.run = mlflow.get_run(run_id) + self.artifact_uri = self.__correct_artifact_uri(self.run.info.to_proto().artifact_uri, mlruns_dir) + + @staticmethod + def __correct_artifact_uri(run_artifact_uri, base_path): + _, suffix = run_artifact_uri.split("mlruns/") + return os.path.join(base_path, suffix) + + def get_weights_path(self, prefix="tt"): + path = os.path.join(self.artifact_uri, prefix, "train_dev", "estimator", "weights.h5") + return path + + def get_classes(self, prefix="tt"): + classes = json.loads( + self.run.data.params[os.path.join(prefix, "train_dev/estimator/classes")].replace("'", '"') + ) + return classes + + def get_model_handle(self, base_weights=None): + weights_path = self.get_weights_path() + model_handle_builder = load_object(self.run.data.params["model_handle_builder"].strip()) + model_handle = model_handle_builder(self.get_classes(), base_weights=base_weights) + model_handle.load_top_weights(weights_path) + return model_handle + + +class MlflowLoader(ModelLoader): + + def __init__(self, mlruns_dir): + self.__mlruns_dir = mlruns_dir + self.__model_handle = None + + def load_model(self, run_id): + if not self.__model_handle: + mlflow_reader = MlflowModelReader(run_id, mlruns_dir=self.__mlruns_dir) + model_handel = mlflow_reader.get_model_handle(BASE_WEIGHTS) + self.__model_handle = model_handel + + return self.__model_handle + + def load_classes(self, run_id): + model_handle = self.load_model(run_id) + + classes = model_handle.model.classes_ + classes_readable = np.array(model_handle.classes) + classes_readable_aligned = classes_readable[classes[list(range(len(classes)))]] + + return classes_readable_aligned diff --git a/image_prediction/model_loader/loaders/mock.py b/image_prediction/model_loader/loaders/mock.py new file mode 100644 index 0000000..be9ff82 --- /dev/null +++ b/image_prediction/model_loader/loaders/mock.py @@ -0,0 +1,15 @@ +from image_prediction.model_loader.loader import ModelLoader + + +class ModelLoaderMock(ModelLoader): + + model = None + classes = None + + def load_model(self, identifier): + assert self.model is not None, "Set the model to be returned first via monkeypatching" + return self.model + + def load_classes(self, identifier): + assert self.classes is not None, "Set the classes to be returned first via monkeypatching" + return self.classes diff --git a/image_prediction/model_loading.py b/image_prediction/model_loading.py new file mode 100644 index 0000000..a3f0a32 --- /dev/null +++ b/image_prediction/model_loading.py @@ -0,0 +1,17 @@ +from collections import namedtuple + +from image_prediction.locations import MLRUNS_DIR +from image_prediction.model_loader.loader import ModelLoader +from image_prediction.model_loader.loaders.mlflow import MlflowLoader + +ModelClassesPair = namedtuple("ModelClassesPair", ["model", "classes"]) + + +def load_model_and_classes(identifier, model_loader: ModelLoader = None) -> ModelClassesPair: + if not model_loader: + model_loader = MlflowLoader(MLRUNS_DIR) + + model = model_loader.load_model(identifier) + classes = model_loader.load_classes(identifier) + + return ModelClassesPair(model, classes) diff --git a/test/unit_tests/conftest.py b/test/unit_tests/conftest.py index b904c76..b20f662 100644 --- a/test/unit_tests/conftest.py +++ b/test/unit_tests/conftest.py @@ -12,10 +12,11 @@ from image_prediction.classifier.classifier import Classifier from image_prediction.classifier.image_classifier import ImageClassifier from image_prediction.estimator.adapter.adapters.keras import KerasEstimatorAdapter from image_prediction.estimator.adapter.adapters.mock import EstimatorMock, EstimatorAdapterMock -from image_prediction.exceptions import UnknownEstimatorAdapter, UnknownImageExtractor +from image_prediction.exceptions import UnknownEstimatorAdapter, UnknownImageExtractor, UnknownModelLoader from image_prediction.image_extractor.extractor import ImageMetadataPair from image_prediction.image_extractor.extractors.mock import ImageExtractorMock from image_prediction.image_extractor.extractors.parsable import ParsablePDFImageExtractor +from image_prediction.model_loader.loaders.mock import ModelLoaderMock @pytest.fixture @@ -201,3 +202,17 @@ def add_image_to_last_page(pdf: fpdf.fpdf.FPDF, image_metadata_pair): with tempfile.NamedTemporaryFile(suffix=".png") as temp_image: image.save(temp_image.name) pdf.image(temp_image.name, x=x, y=y, w=w, h=h, type="png") + + +@pytest.fixture +def model_loader(loader_type, monkeypatch, estimator_adapter, classes): + if loader_type == "mock": + loader = ModelLoaderMock() + else: + raise UnknownModelLoader(f"No model loader for type {loader_type} was specified.") + + # estimator adapter just has suitable interface for test, but is not actually what the model loader what typically + # load. Rather the model loader would load sklearn or tensorflow models (or redai model handles). + monkeypatch.setattr(loader, "model", estimator_adapter) + monkeypatch.setattr(loader, "classes", classes) + return loader diff --git a/test/unit_tests/model_loader_test.py b/test/unit_tests/model_loader_test.py new file mode 100644 index 0000000..8d0420a --- /dev/null +++ b/test/unit_tests/model_loader_test.py @@ -0,0 +1,12 @@ +import pytest + +from image_prediction.model_loading import load_model_and_classes + + +@pytest.mark.parametrize("loader_type", ["mock"]) +@pytest.mark.parametrize("estimator_type", ["mock"]) +@pytest.mark.parametrize("batch_size", [3]) +def test_load_model_and_classes(model_loader, estimator_adapter, classes): + model_loaded, classes_loaded = load_model_and_classes("some random identifier", model_loader=model_loader) + assert model_loaded == estimator_adapter + assert classes_loaded == classes