This commit is contained in:
Matthias Bisping 2022-03-26 19:27:37 +01:00
parent 373c619b0c
commit ea298dacfa
2 changed files with 7 additions and 7 deletions

View File

@ -10,18 +10,18 @@ from image_prediction.predictor.predictor import Predictor
@pytest.fixture @pytest.fixture
def predictor(service_estimator): def predictor(estimator):
return Predictor(service_estimator) return Predictor(estimator)
@pytest.fixture @pytest.fixture
def service_estimator(estimator, classes): def estimator(estimator_adapter, classes):
service_estimator = Estimator(estimator, classes) service_estimator = Estimator(estimator_adapter, classes)
return service_estimator return service_estimator
@pytest.fixture @pytest.fixture
def estimator(estimator_type, keras_model, output_batch, monkeypatch): def estimator_adapter(estimator_type, keras_model, output_batch, monkeypatch):
if estimator_type == "mock": if estimator_type == "mock":
estimator = EstimatorAdapterMock(DummyEstimator()) estimator = EstimatorAdapterMock(DummyEstimator())
elif estimator_type == "keras": elif estimator_type == "keras":

View File

@ -10,8 +10,8 @@ logger.setLevel(logging.DEBUG)
@pytest.mark.parametrize("estimator_type", ["mock", "keras"], scope="session") @pytest.mark.parametrize("estimator_type", ["mock", "keras"], scope="session")
@pytest.mark.parametrize("batch_size", [0, 1, 2, 16, 32, 64], scope="session") @pytest.mark.parametrize("batch_size", [0, 1, 2, 16, 32, 64], scope="session")
def test_predict(service_estimator, input_batch, expected_predictions): def test_predict(estimator, input_batch, expected_predictions):
predictions = service_estimator.predict(input_batch) predictions = estimator.predict(input_batch)
assert predictions == expected_predictions assert predictions == expected_predictions