diff --git a/test/unit_tests/service_estimator_test.py b/test/unit_tests/service_estimator_test.py index 42358fe..3609505 100644 --- a/test/unit_tests/service_estimator_test.py +++ b/test/unit_tests/service_estimator_test.py @@ -12,15 +12,25 @@ logger.setLevel(logging.DEBUG) @pytest.fixture(scope="session") -def estimator(): - return EstimatorMock() +def estimator(output_batch): + estimator = EstimatorMock() + estimator.output_batch = output_batch + return estimator @pytest.fixture(scope="session") -def batches(batch_size, classes): - input_batch = np.random.normal(size=(batch_size, 10, 15)) - output_batch = np.random.randint(low=0, high=len(classes), size=batch_size) - return input_batch, output_batch +def input_batch(batch_size, classes): + return np.random.normal(size=(batch_size, 10, 15)) + + +@pytest.fixture(scope="session") +def output_batch(batch_size, classes): + return np.random.randint(low=0, high=len(classes), size=batch_size) + + +@pytest.fixture(scope="session") +def expected_predictions(output_batch, classes): + return map_labels(output_batch, classes) @pytest.fixture(scope="session") @@ -40,13 +50,6 @@ def service_estimator(model_type, estimator, classes): @pytest.mark.parametrize("model_type", ["mock"], scope="session") @pytest.mark.parametrize("batch_size", [0, 1, 2, 16, 32, 64], scope="session") -def test_predict(service_estimator, batches, classes): - - input_batch, output_batch = batches - expected_predictions = map_labels(output_batch, classes) - - service_estimator.estimator.output_batch = output_batch - +def test_predict(service_estimator, input_batch, expected_predictions): predictions = service_estimator.predict(input_batch) - assert predictions == expected_predictions