image-classification-service/test/unit_tests/service_estimator_test.py
2022-03-25 11:42:31 +01:00

42 lines
1.2 KiB
Python

import numpy as np
import pytest
from image_prediction.estimator.mock import EstimatorMock
from image_prediction.service_estimator.mock import ServiceEstimatorMock
@pytest.fixture(scope="session")
def estimator():
    """Provide a single ``EstimatorMock`` instance shared across the test session."""
    mock_estimator = EstimatorMock()
    return mock_estimator
@pytest.fixture(scope="session")
def batches(batch_size, classes):
    """Create one random input batch together with matching numeric labels.

    Args:
        batch_size: Number of samples in the batch (parametrized by the test).
        classes: Class-name list; only its length is used here, to bound the
            generated labels.

    Returns:
        A tuple ``(input_batch, output_batch)``. ``input_batch`` has shape
        ``(batch_size, 10, 15)`` and is drawn from a standard normal
        distribution. ``output_batch`` holds ``batch_size`` integer labels in
        ``[0, len(classes))``.
    """
    # A seeded, local generator keeps the data random-looking but makes every
    # run (and therefore every failure) reproducible. It also leaves the
    # global NumPy RNG state untouched for other tests.
    rng = np.random.default_rng(seed=0)
    input_batch = rng.normal(size=(batch_size, 10, 15))
    # ``high`` is exclusive, so every label is a valid index into ``classes``.
    output_batch = rng.integers(low=0, high=len(classes), size=batch_size)
    return input_batch, output_batch
@pytest.fixture(scope="session")
def classes():
    """Return the class names used by the tests.

    Position ``i`` in the returned list is the name for numeric label ``i``.
    """
    return list("ABC")
def map_labels(numeric_labels, classes):
    """Translate numeric class indices into their class names.

    Each entry of ``numeric_labels`` is used as an index into ``classes``.
    The names are returned as a new list in the same order.
    """
    lookup = classes.__getitem__
    return list(map(lookup, numeric_labels))
@pytest.fixture(scope="session")
def service_estimator(model_type, estimator, classes):
    """Build the service estimator under test for the given ``model_type``.

    ``model_type`` is supplied by ``pytest.mark.parametrize`` on the test.
    ``estimator`` and ``classes`` come from the fixtures of the same name.

    Raises:
        ValueError: If ``model_type`` has no implementation wired up here.
            Failing loudly at setup is clearer than returning ``None`` and
            crashing later with an unrelated ``AttributeError``.
    """
    if model_type == "mock":
        return ServiceEstimatorMock(estimator, classes)
    raise ValueError(f"Unsupported model_type: {model_type!r}")
@pytest.mark.parametrize("model_type", ["mock"], scope="session")
@pytest.mark.parametrize("batch_size", [0, 1, 2, 16, 32, 64], scope="session")
def test_predict(service_estimator, batches, classes):
    """Check that ``predict`` returns the wrapped estimator's labels as class names.

    The wrapped estimator is primed with the numeric labels. ``predict`` is
    then expected to return those labels translated through ``classes``.
    """
    inputs, numeric_labels = batches
    # Prime the wrapped estimator. Presumably the mock returns this array as
    # its raw prediction (TODO: confirm against EstimatorMock).
    # NOTE(review): this mutates a session-scoped fixture and is not restored.
    service_estimator.estimator.output_batch = numeric_labels
    expected = map_labels(numeric_labels, classes)
    actual = service_estimator.predict(inputs)
    assert actual == expected