From 3b4c2a40b2674b95d74e83a9339caf81697e6bd3 Mon Sep 17 00:00:00 2001 From: Matthias Bisping Date: Mon, 28 Mar 2022 21:51:17 +0200 Subject: [PATCH] added patched test for mlflow model loader --- image_prediction/classifier/classifier.py | 4 +-- .../model_loader/loaders/mlflow.py | 8 +++--- test/unit_tests/conftest.py | 25 +++++++++++++++---- test/unit_tests/model_loader_test.py | 9 ++++--- 4 files changed, 31 insertions(+), 15 deletions(-) diff --git a/image_prediction/classifier/classifier.py b/image_prediction/classifier/classifier.py index 53f474f..f6dc3c6 100644 --- a/image_prediction/classifier/classifier.py +++ b/image_prediction/classifier/classifier.py @@ -19,14 +19,14 @@ class Classifier: classes: mapping from a numerical label to a human-readable label for classes """ self.__estimator_adapter = estimator_adapter - self.__classes = classes + self._classes = classes def predict(self, batch: np.array) -> List[str]: if batch.shape[0] == 0: return [] - return [self.__classes[numeric_label] for numeric_label in self.__estimator_adapter.predict(batch)] + return [self._classes[numeric_label] for numeric_label in self.__estimator_adapter.predict(batch)] def __call__(self, batch: np.array) -> List[str]: return self.predict(batch) diff --git a/image_prediction/model_loader/loaders/mlflow.py b/image_prediction/model_loader/loaders/mlflow.py index 8495e30..dcbce5e 100644 --- a/image_prediction/model_loader/loaders/mlflow.py +++ b/image_prediction/model_loader/loaders/mlflow.py @@ -63,15 +63,15 @@ class MlflowLoader(ModelLoader): def __init__(self, mlruns_dir): self.__mlruns_dir = mlruns_dir - self.__model_handle = None + self._model_handle = None def load_model(self, run_id): - if not self.__model_handle: + if not self._model_handle: mlflow_reader = MlflowModelReader(run_id, mlruns_dir=self.__mlruns_dir) model_handel = mlflow_reader.get_model_handle(BASE_WEIGHTS) - self.__model_handle = model_handel + self._model_handle = model_handel - return self.__model_handle + return self._model_handle def load_classes(self, run_id): model_handle = self.load_model(run_id) diff --git a/test/unit_tests/conftest.py b/test/unit_tests/conftest.py index b20f662..4b281d8 100644 --- a/test/unit_tests/conftest.py +++ b/test/unit_tests/conftest.py @@ -16,6 +16,7 @@ from image_prediction.exceptions import UnknownEstimatorAdapter, UnknownImageExt from image_prediction.image_extractor.extractor import ImageMetadataPair from image_prediction.image_extractor.extractors.mock import ImageExtractorMock from image_prediction.image_extractor.extractors.parsable import ParsablePDFImageExtractor +from image_prediction.model_loader.loaders.mlflow import MlflowLoader from image_prediction.model_loader.loaders.mock import ModelLoaderMock @@ -205,14 +206,28 @@ def add_image_to_last_page(pdf: fpdf.fpdf.FPDF, image_metadata_pair): @pytest.fixture -def model_loader(loader_type, monkeypatch, estimator_adapter, classes): +def model_handle_mock(classes, classifier): + + class ModelHandleMock: + + def __init__(self, classes): + classifier.classes_ = np.array(list(range(len(classes)))) + self.classes = classes + self.model = classifier + + return ModelHandleMock(classes) + + +@pytest.fixture +def model_loader(loader_type, monkeypatch, model_handle_mock, classes): if loader_type == "mock": loader = ModelLoaderMock() + monkeypatch.setattr(loader, "model", model_handle_mock) + monkeypatch.setattr(loader, "classes", classes) + elif loader_type == "mlflow": + loader = MlflowLoader("...") + monkeypatch.setattr(loader, "_model_handle", model_handle_mock) else: raise UnknownModelLoader(f"No model loader for type {loader_type} was specified.") - # estimator adapter just has suitable interface for test, but is not actually what the model loader what typically - # load. Rather the model loader would load sklearn or tensorflow models (or redai model handles). - monkeypatch.setattr(loader, "model", estimator_adapter) - monkeypatch.setattr(loader, "classes", classes) return loader diff --git a/test/unit_tests/model_loader_test.py b/test/unit_tests/model_loader_test.py index 8d0420a..3139c7a 100644 --- a/test/unit_tests/model_loader_test.py +++ b/test/unit_tests/model_loader_test.py @@ -1,12 +1,13 @@ +import numpy as np import pytest from image_prediction.model_loading import load_model_and_classes -@pytest.mark.parametrize("loader_type", ["mock"]) +@pytest.mark.parametrize("loader_type", ["mock", "mlflow"]) @pytest.mark.parametrize("estimator_type", ["mock"]) @pytest.mark.parametrize("batch_size", [3]) -def test_load_model_and_classes(model_loader, estimator_adapter, classes): +def test_load_model_and_classes(model_loader, model_handle_mock, classes): model_loaded, classes_loaded = load_model_and_classes("some random identifier", model_loader=model_loader) - assert model_loaded == estimator_adapter - assert classes_loaded == classes + assert model_loaded == model_handle_mock + assert np.all(classes_loaded == classes)