diff --git a/test/unit_tests/conftest.py b/test/unit_tests/conftest.py index 460bf83..0525abc 100644 --- a/test/unit_tests/conftest.py +++ b/test/unit_tests/conftest.py @@ -10,18 +10,18 @@ from image_prediction.predictor.predictor import Predictor @pytest.fixture -def predictor(service_estimator): - return Predictor(service_estimator) +def predictor(estimator): + return Predictor(estimator) @pytest.fixture -def service_estimator(estimator, classes): - service_estimator = Estimator(estimator, classes) +def estimator(estimator_adapter, classes): + service_estimator = Estimator(estimator_adapter, classes) return service_estimator @pytest.fixture -def estimator(estimator_type, keras_model, output_batch, monkeypatch): +def estimator_adapter(estimator_type, keras_model, output_batch, monkeypatch): if estimator_type == "mock": estimator = EstimatorAdapterMock(DummyEstimator()) elif estimator_type == "keras": diff --git a/test/unit_tests/service_estimator_test.py b/test/unit_tests/estimator_test.py similarity index 83% rename from test/unit_tests/service_estimator_test.py rename to test/unit_tests/estimator_test.py index 347e803..b0eb33a 100644 --- a/test/unit_tests/service_estimator_test.py +++ b/test/unit_tests/estimator_test.py @@ -10,8 +10,8 @@ logger.setLevel(logging.DEBUG) @pytest.mark.parametrize("estimator_type", ["mock", "keras"], scope="session") @pytest.mark.parametrize("batch_size", [0, 1, 2, 16, 32, 64], scope="session") -def test_predict(service_estimator, input_batch, expected_predictions): - predictions = service_estimator.predict(input_batch) +def test_predict(estimator, input_batch, expected_predictions): + predictions = estimator.predict(input_batch) assert predictions == expected_predictions