62 lines
1.7 KiB
Python
62 lines
1.7 KiB
Python
import logging
|
|
|
|
import numpy as np
|
|
import pytest
|
|
|
|
from image_prediction.estimator.estimators.mock import EstimatorMock
|
|
from image_prediction.estimator.adapter.patch import EstimatorAdapterPatch
|
|
from image_prediction.service_estimator.service_estimator import ServiceEstimator
|
|
from image_prediction.utils import get_logger
|
|
|
|
# Module-wide logger for this test module; its level is raised to DEBUG so that
# debug-level messages are emitted while these tests run.
logger = get_logger()
logger.setLevel(logging.DEBUG)  # most verbose standard level
|
|
|
|
|
|
@pytest.fixture(scope="session")
def estimator(estimator_type):
    """Build the raw estimator selected by the ``estimator_type`` parameter.

    ``estimator_type`` is supplied via ``pytest.mark.parametrize`` on the test
    that (transitively) requests this fixture.

    Returns:
        An ``EstimatorMock`` when ``estimator_type == "mock"``.

    Raises:
        ValueError: for any other ``estimator_type``. Failing loudly here avoids
            silently handing ``None`` on to dependent fixtures such as
            ``estimator_adapter``.
    """
    if estimator_type == "mock":
        return EstimatorMock()
    raise ValueError(f"Unsupported estimator_type: {estimator_type!r}")
|
|
|
|
|
|
@pytest.fixture(scope="session")
def estimator_adapter(output_batch, estimator):
    """Wrap ``estimator`` in an ``EstimatorAdapterPatch`` and attach ``output_batch``.

    The adapter's ``output_batch`` attribute is overwritten with the
    ``output_batch`` fixture value before being returned. Presumably the patch
    adapter reports this array as its prediction output — confirm against
    ``EstimatorAdapterPatch``.
    """
    adapter = EstimatorAdapterPatch(estimator)
    adapter.output_batch = output_batch
    return adapter
|
|
|
|
|
|
@pytest.fixture(scope="session")
def input_batch(batch_size):
    """Random float input batch of shape ``(batch_size, 10, 15)``.

    Args:
        batch_size: number of items in the batch (parametrized on the test;
            may be 0, which yields an empty array of shape ``(0, 10, 15)``).

    Returns:
        A NumPy array of standard-normal samples.
    """
    # Fixed seed so any failing parametrization can be reproduced exactly,
    # instead of depending on the unseeded global NumPy RNG state.
    rng = np.random.default_rng(seed=0)
    return rng.normal(size=(batch_size, 10, 15))
|
|
|
|
|
|
@pytest.fixture(scope="session")
def output_batch(batch_size, classes):
    """Random integer class labels, one per batch item.

    Args:
        batch_size: number of labels to draw (may be 0).
        classes: list of class names; labels are drawn from ``[0, len(classes))``.

    Returns:
        A 1-D NumPy integer array of length ``batch_size``.
    """
    # Fixed seed so any failing parametrization can be reproduced exactly.
    # ``Generator.integers`` excludes ``high`` by default, matching the
    # half-open range of the legacy ``np.random.randint``.
    rng = np.random.default_rng(seed=1)
    return rng.integers(low=0, high=len(classes), size=batch_size)
|
|
|
|
|
|
@pytest.fixture(scope="session")
def expected_predictions(output_batch, classes):
    """Class names that ``test_predict`` expects back for ``output_batch``.

    Each integer label ``i`` in ``output_batch`` is translated to ``classes[i]``.
    """
    labels_as_names = map_labels(output_batch, classes)
    return labels_as_names
|
|
|
|
|
|
@pytest.fixture(scope="session")
def classes():
    """Class names used by these tests, indexed by integer label (0 -> "A", ...)."""
    return list("ABC")
|
|
|
|
|
|
def map_labels(numeric_labels, classes):
    """Translate integer class indices into their class names.

    Args:
        numeric_labels: iterable of integer indices into ``classes``.
        classes: indexable sequence of class names.

    Returns:
        A new list with ``classes[label]`` for each label, in input order.
    """
    names = []
    for index in numeric_labels:
        names.append(classes[index])
    return names
|
|
|
|
|
|
@pytest.fixture(scope="session")
def service_estimator(estimator_adapter, classes):
    """The ``ServiceEstimator`` under test, built from the adapter and class names."""
    service = ServiceEstimator(estimator_adapter, classes)
    return service
|
|
|
|
|
|
@pytest.mark.parametrize("estimator_type", ["mock"], scope="session")
@pytest.mark.parametrize("batch_size", [0, 1, 2, 16, 32, 64], scope="session")
def test_predict(service_estimator, input_batch, expected_predictions):
    """Check that ``predict`` yields the class names of the adapter's output labels.

    Runs once per ``batch_size``, including the empty batch.
    """
    assert service_estimator.predict(input_batch) == expected_predictions
|