added patched test for mlflow model loader
This commit is contained in:
parent
c06905625d
commit
3b4c2a40b2
@ -19,14 +19,14 @@ class Classifier:
|
||||
classes: mapping from a numerical label to a human-readable label for classes
|
||||
"""
|
||||
self.__estimator_adapter = estimator_adapter
|
||||
self.__classes = classes
|
||||
self._classes = classes
|
||||
|
||||
def predict(self, batch: np.array) -> List[str]:
|
||||
|
||||
if batch.shape[0] == 0:
|
||||
return []
|
||||
|
||||
return [self.__classes[numeric_label] for numeric_label in self.__estimator_adapter.predict(batch)]
|
||||
return [self._classes[numeric_label] for numeric_label in self.__estimator_adapter.predict(batch)]
|
||||
|
||||
def __call__(self, batch: np.array) -> List[str]:
|
||||
return self.predict(batch)
|
||||
|
||||
@ -63,15 +63,15 @@ class MlflowLoader(ModelLoader):
|
||||
|
||||
def __init__(self, mlruns_dir):
|
||||
self.__mlruns_dir = mlruns_dir
|
||||
self.__model_handle = None
|
||||
self._model_handle = None
|
||||
|
||||
def load_model(self, run_id):
|
||||
if not self.__model_handle:
|
||||
if not self._model_handle:
|
||||
mlflow_reader = MlflowModelReader(run_id, mlruns_dir=self.__mlruns_dir)
|
||||
model_handel = mlflow_reader.get_model_handle(BASE_WEIGHTS)
|
||||
self.__model_handle = model_handel
|
||||
self._model_handle = model_handel
|
||||
|
||||
return self.__model_handle
|
||||
return self._model_handle
|
||||
|
||||
def load_classes(self, run_id):
|
||||
model_handle = self.load_model(run_id)
|
||||
|
||||
@ -16,6 +16,7 @@ from image_prediction.exceptions import UnknownEstimatorAdapter, UnknownImageExt
|
||||
from image_prediction.image_extractor.extractor import ImageMetadataPair
|
||||
from image_prediction.image_extractor.extractors.mock import ImageExtractorMock
|
||||
from image_prediction.image_extractor.extractors.parsable import ParsablePDFImageExtractor
|
||||
from image_prediction.model_loader.loaders.mlflow import MlflowLoader
|
||||
from image_prediction.model_loader.loaders.mock import ModelLoaderMock
|
||||
|
||||
|
||||
@ -205,14 +206,28 @@ def add_image_to_last_page(pdf: fpdf.fpdf.FPDF, image_metadata_pair):
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def model_loader(loader_type, monkeypatch, estimator_adapter, classes):
|
||||
def model_handle_mock(classes, classifier):
|
||||
|
||||
class ModelHandleMock:
|
||||
|
||||
def __init__(self, classes):
|
||||
classifier.classes_ = np.array(list(range(len(classes))))
|
||||
self.classes = classes
|
||||
self.model = classifier
|
||||
|
||||
return ModelHandleMock(classes)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def model_loader(loader_type, monkeypatch, model_handle_mock, classes):
|
||||
if loader_type == "mock":
|
||||
loader = ModelLoaderMock()
|
||||
monkeypatch.setattr(loader, "model", model_handle_mock)
|
||||
monkeypatch.setattr(loader, "classes", classes)
|
||||
elif loader_type == "mlflow":
|
||||
loader = MlflowLoader("...")
|
||||
monkeypatch.setattr(loader, "_model_handle", model_handle_mock)
|
||||
else:
|
||||
raise UnknownModelLoader(f"No model loader for type {loader_type} was specified.")
|
||||
|
||||
# estimator adapter just has suitable interface for test, but is not actually what the model loader what typically
|
||||
# load. Rather the model loader would load sklearn or tensorflow models (or redai model handles).
|
||||
monkeypatch.setattr(loader, "model", estimator_adapter)
|
||||
monkeypatch.setattr(loader, "classes", classes)
|
||||
return loader
|
||||
|
||||
@ -1,12 +1,13 @@
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from image_prediction.model_loading import load_model_and_classes
|
||||
|
||||
|
||||
@pytest.mark.parametrize("loader_type", ["mock"])
|
||||
@pytest.mark.parametrize("loader_type", ["mock", "mlflow"])
|
||||
@pytest.mark.parametrize("estimator_type", ["mock"])
|
||||
@pytest.mark.parametrize("batch_size", [3])
|
||||
def test_load_model_and_classes(model_loader, estimator_adapter, classes):
|
||||
def test_load_model_and_classes(model_loader, model_handle_mock, classes):
|
||||
model_loaded, classes_loaded = load_model_and_classes("some random identifier", model_loader=model_loader)
|
||||
assert model_loaded == estimator_adapter
|
||||
assert classes_loaded == classes
|
||||
assert model_loaded == model_handle_mock
|
||||
assert np.all(classes_loaded == classes)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user