14 lines
535 B
Python
14 lines
535 B
Python
import numpy as np
|
|
import pytest
|
|
|
|
from image_prediction.model_loading import load_model_and_classes
|
|
|
|
|
|
@pytest.mark.parametrize("loader_type", ["mock", "mlflow"])
@pytest.mark.parametrize("estimator_type", ["mock"])
@pytest.mark.parametrize("batch_size", [3])
def test_load_model_and_classes(model_loader, model_handle_mock, classes):
    """Check that ``load_model_and_classes`` returns the expected model and classes.

    The result is compared against the ``model_handle_mock`` and ``classes``
    fixtures. The ``loader_type``, ``estimator_type`` and ``batch_size``
    parameters are not arguments of this test; presumably they are consumed by
    the ``model_loader`` / ``model_handle_mock`` / ``classes`` fixtures defined
    elsewhere (e.g. a conftest) -- confirm there.
    """
    # The identifier value is arbitrary here; presumably the loaders under test
    # do not depend on it -- TODO confirm for the "mlflow" loader fixture.
    model_loaded, classes_loaded = load_model_and_classes(
        "some random identifier", model_loader=model_loader
    )

    assert model_loaded == model_handle_mock
    # assert_array_equal also checks shapes, so a broadcast-compatible but
    # differently-shaped result (e.g. an empty array) cannot pass silently the
    # way ``np.all(a == b)`` would; on failure it reports which elements differ.
    np.testing.assert_array_equal(classes_loaded, classes)