added patched test for mlflow model loader
This commit is contained in:
parent
c06905625d
commit
3b4c2a40b2
@ -19,14 +19,14 @@ class Classifier:
|
|||||||
classes: mapping from a numerical label to a human-readable label for classes
|
classes: mapping from a numerical label to a human-readable label for classes
|
||||||
"""
|
"""
|
||||||
self.__estimator_adapter = estimator_adapter
|
self.__estimator_adapter = estimator_adapter
|
||||||
self.__classes = classes
|
self._classes = classes
|
||||||
|
|
||||||
def predict(self, batch: np.array) -> List[str]:
|
def predict(self, batch: np.array) -> List[str]:
|
||||||
|
|
||||||
if batch.shape[0] == 0:
|
if batch.shape[0] == 0:
|
||||||
return []
|
return []
|
||||||
|
|
||||||
return [self.__classes[numeric_label] for numeric_label in self.__estimator_adapter.predict(batch)]
|
return [self._classes[numeric_label] for numeric_label in self.__estimator_adapter.predict(batch)]
|
||||||
|
|
||||||
def __call__(self, batch: np.array) -> List[str]:
|
def __call__(self, batch: np.array) -> List[str]:
|
||||||
return self.predict(batch)
|
return self.predict(batch)
|
||||||
|
|||||||
@ -63,15 +63,15 @@ class MlflowLoader(ModelLoader):
|
|||||||
|
|
||||||
def __init__(self, mlruns_dir):
|
def __init__(self, mlruns_dir):
|
||||||
self.__mlruns_dir = mlruns_dir
|
self.__mlruns_dir = mlruns_dir
|
||||||
self.__model_handle = None
|
self._model_handle = None
|
||||||
|
|
||||||
def load_model(self, run_id):
|
def load_model(self, run_id):
|
||||||
if not self.__model_handle:
|
if not self._model_handle:
|
||||||
mlflow_reader = MlflowModelReader(run_id, mlruns_dir=self.__mlruns_dir)
|
mlflow_reader = MlflowModelReader(run_id, mlruns_dir=self.__mlruns_dir)
|
||||||
model_handel = mlflow_reader.get_model_handle(BASE_WEIGHTS)
|
model_handel = mlflow_reader.get_model_handle(BASE_WEIGHTS)
|
||||||
self.__model_handle = model_handel
|
self._model_handle = model_handel
|
||||||
|
|
||||||
return self.__model_handle
|
return self._model_handle
|
||||||
|
|
||||||
def load_classes(self, run_id):
|
def load_classes(self, run_id):
|
||||||
model_handle = self.load_model(run_id)
|
model_handle = self.load_model(run_id)
|
||||||
|
|||||||
@ -16,6 +16,7 @@ from image_prediction.exceptions import UnknownEstimatorAdapter, UnknownImageExt
|
|||||||
from image_prediction.image_extractor.extractor import ImageMetadataPair
|
from image_prediction.image_extractor.extractor import ImageMetadataPair
|
||||||
from image_prediction.image_extractor.extractors.mock import ImageExtractorMock
|
from image_prediction.image_extractor.extractors.mock import ImageExtractorMock
|
||||||
from image_prediction.image_extractor.extractors.parsable import ParsablePDFImageExtractor
|
from image_prediction.image_extractor.extractors.parsable import ParsablePDFImageExtractor
|
||||||
|
from image_prediction.model_loader.loaders.mlflow import MlflowLoader
|
||||||
from image_prediction.model_loader.loaders.mock import ModelLoaderMock
|
from image_prediction.model_loader.loaders.mock import ModelLoaderMock
|
||||||
|
|
||||||
|
|
||||||
@ -205,14 +206,28 @@ def add_image_to_last_page(pdf: fpdf.fpdf.FPDF, image_metadata_pair):
|
|||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def model_loader(loader_type, monkeypatch, estimator_adapter, classes):
|
def model_handle_mock(classes, classifier):
|
||||||
|
|
||||||
|
class ModelHandleMock:
|
||||||
|
|
||||||
|
def __init__(self, classes):
|
||||||
|
classifier.classes_ = np.array(list(range(len(classes))))
|
||||||
|
self.classes = classes
|
||||||
|
self.model = classifier
|
||||||
|
|
||||||
|
return ModelHandleMock(classes)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def model_loader(loader_type, monkeypatch, model_handle_mock, classes):
|
||||||
if loader_type == "mock":
|
if loader_type == "mock":
|
||||||
loader = ModelLoaderMock()
|
loader = ModelLoaderMock()
|
||||||
|
monkeypatch.setattr(loader, "model", model_handle_mock)
|
||||||
|
monkeypatch.setattr(loader, "classes", classes)
|
||||||
|
elif loader_type == "mlflow":
|
||||||
|
loader = MlflowLoader("...")
|
||||||
|
monkeypatch.setattr(loader, "_model_handle", model_handle_mock)
|
||||||
else:
|
else:
|
||||||
raise UnknownModelLoader(f"No model loader for type {loader_type} was specified.")
|
raise UnknownModelLoader(f"No model loader for type {loader_type} was specified.")
|
||||||
|
|
||||||
# estimator adapter just has suitable interface for test, but is not actually what the model loader what typically
|
|
||||||
# load. Rather the model loader would load sklearn or tensorflow models (or redai model handles).
|
|
||||||
monkeypatch.setattr(loader, "model", estimator_adapter)
|
|
||||||
monkeypatch.setattr(loader, "classes", classes)
|
|
||||||
return loader
|
return loader
|
||||||
|
|||||||
@ -1,12 +1,13 @@
|
|||||||
|
import numpy as np
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from image_prediction.model_loading import load_model_and_classes
|
from image_prediction.model_loading import load_model_and_classes
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("loader_type", ["mock"])
|
@pytest.mark.parametrize("loader_type", ["mock", "mlflow"])
|
||||||
@pytest.mark.parametrize("estimator_type", ["mock"])
|
@pytest.mark.parametrize("estimator_type", ["mock"])
|
||||||
@pytest.mark.parametrize("batch_size", [3])
|
@pytest.mark.parametrize("batch_size", [3])
|
||||||
def test_load_model_and_classes(model_loader, estimator_adapter, classes):
|
def test_load_model_and_classes(model_loader, model_handle_mock, classes):
|
||||||
model_loaded, classes_loaded = load_model_and_classes("some random identifier", model_loader=model_loader)
|
model_loaded, classes_loaded = load_model_and_classes("some random identifier", model_loader=model_loader)
|
||||||
assert model_loaded == estimator_adapter
|
assert model_loaded == model_handle_mock
|
||||||
assert classes_loaded == classes
|
assert np.all(classes_loaded == classes)
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user