2022-03-28 21:51:21 +02:00

14 lines
535 B
Python

import numpy as np
import pytest
from image_prediction.model_loading import load_model_and_classes
@pytest.mark.parametrize("loader_type", ["mock", "mlflow"])
@pytest.mark.parametrize("estimator_type", ["mock"])
@pytest.mark.parametrize("batch_size", [3])
def test_load_model_and_classes(model_loader, model_handle_mock, classes):
    """Check that ``load_model_and_classes`` hands back the expected model and classes.

    The call is made with the ``model_loader`` fixture and its result is
    compared against the ``model_handle_mock`` and ``classes`` fixtures.

    NOTE(review): the parametrized names (``loader_type``, ``estimator_type``,
    ``batch_size``) are not arguments of this test; presumably they are
    consumed by the fixtures requested here -- confirm against the conftest.
    """
    # The identifier string appears to be an arbitrary placeholder.
    loaded = load_model_and_classes(
        "some random identifier",
        model_loader=model_loader,
    )
    model, class_labels = loaded

    assert model == model_handle_mock

    # Compare the classes with ``==`` and reduce with np.all, so the check
    # also works when the comparison yields an array rather than a bool.
    labels_match = class_labels == classes
    assert np.all(labels_match)