diff --git a/test/unit_tests/service_estimator_test.py b/test/unit_tests/service_estimator_test.py index b0ffc1c..42358fe 100644 --- a/test/unit_tests/service_estimator_test.py +++ b/test/unit_tests/service_estimator_test.py @@ -1,8 +1,14 @@ +import logging + import numpy as np import pytest from image_prediction.estimator.mock import EstimatorMock from image_prediction.service_estimator.mock import ServiceEstimatorMock +from image_prediction.utils import get_logger + +logger = get_logger() +logger.setLevel(logging.DEBUG) @pytest.fixture(scope="session") @@ -35,7 +41,12 @@ def service_estimator(model_type, estimator, classes): @pytest.mark.parametrize("model_type", ["mock"], scope="session") @pytest.mark.parametrize("batch_size", [0, 1, 2, 16, 32, 64], scope="session") def test_predict(service_estimator, batches, classes): + input_batch, output_batch = batches - service_estimator.estimator.output_batch = output_batch expected_predictions = map_labels(output_batch, classes) - assert service_estimator.predict(input_batch) == expected_predictions + + service_estimator.estimator.output_batch = output_batch + + predictions = service_estimator.predict(input_batch) + + assert predictions == expected_predictions