added patched test for mlflow model loader

This commit is contained in:
Matthias Bisping 2022-03-28 21:51:17 +02:00
parent c06905625d
commit 3b4c2a40b2
4 changed files with 31 additions and 15 deletions

View File

@ -19,14 +19,14 @@ class Classifier:
classes: mapping from a numerical label to a human-readable label for classes
"""
self.__estimator_adapter = estimator_adapter
self.__classes = classes
self._classes = classes
def predict(self, batch: np.array) -> List[str]:
if batch.shape[0] == 0:
return []
return [self.__classes[numeric_label] for numeric_label in self.__estimator_adapter.predict(batch)]
return [self._classes[numeric_label] for numeric_label in self.__estimator_adapter.predict(batch)]
def __call__(self, batch: np.array) -> List[str]:
return self.predict(batch)

View File

@ -63,15 +63,15 @@ class MlflowLoader(ModelLoader):
def __init__(self, mlruns_dir):
self.__mlruns_dir = mlruns_dir
self.__model_handle = None
self._model_handle = None
def load_model(self, run_id):
if not self.__model_handle:
if not self._model_handle:
mlflow_reader = MlflowModelReader(run_id, mlruns_dir=self.__mlruns_dir)
model_handel = mlflow_reader.get_model_handle(BASE_WEIGHTS)
self.__model_handle = model_handel
self._model_handle = model_handel
return self.__model_handle
return self._model_handle
def load_classes(self, run_id):
model_handle = self.load_model(run_id)

View File

@ -16,6 +16,7 @@ from image_prediction.exceptions import UnknownEstimatorAdapter, UnknownImageExt
from image_prediction.image_extractor.extractor import ImageMetadataPair
from image_prediction.image_extractor.extractors.mock import ImageExtractorMock
from image_prediction.image_extractor.extractors.parsable import ParsablePDFImageExtractor
from image_prediction.model_loader.loaders.mlflow import MlflowLoader
from image_prediction.model_loader.loaders.mock import ModelLoaderMock
@ -205,14 +206,28 @@ def add_image_to_last_page(pdf: fpdf.fpdf.FPDF, image_metadata_pair):
@pytest.fixture
def model_loader(loader_type, monkeypatch, estimator_adapter, classes):
def model_handle_mock(classes, classifier):
class ModelHandleMock:
def __init__(self, classes):
classifier.classes_ = np.array(list(range(len(classes))))
self.classes = classes
self.model = classifier
return ModelHandleMock(classes)
@pytest.fixture
def model_loader(loader_type, monkeypatch, model_handle_mock, classes):
if loader_type == "mock":
loader = ModelLoaderMock()
monkeypatch.setattr(loader, "model", model_handle_mock)
monkeypatch.setattr(loader, "classes", classes)
elif loader_type == "mlflow":
loader = MlflowLoader("...")
monkeypatch.setattr(loader, "_model_handle", model_handle_mock)
else:
raise UnknownModelLoader(f"No model loader for type {loader_type} was specified.")
# estimator adapter just has suitable interface for test, but is not actually what the model loader what typically
# load. Rather the model loader would load sklearn or tensorflow models (or redai model handles).
monkeypatch.setattr(loader, "model", estimator_adapter)
monkeypatch.setattr(loader, "classes", classes)
return loader

View File

@ -1,12 +1,13 @@
import numpy as np
import pytest
from image_prediction.model_loading import load_model_and_classes
@pytest.mark.parametrize("loader_type", ["mock"])
@pytest.mark.parametrize("loader_type", ["mock", "mlflow"])
@pytest.mark.parametrize("estimator_type", ["mock"])
@pytest.mark.parametrize("batch_size", [3])
def test_load_model_and_classes(model_loader, estimator_adapter, classes):
def test_load_model_and_classes(model_loader, model_handle_mock, classes):
model_loaded, classes_loaded = load_model_and_classes("some random identifier", model_loader=model_loader)
assert model_loaded == estimator_adapter
assert classes_loaded == classes
assert model_loaded == model_handle_mock
assert np.all(classes_loaded == classes)