This commit is contained in:
Matthias Bisping 2022-03-26 19:27:37 +01:00
parent 373c619b0c
commit ea298dacfa
2 changed files with 7 additions and 7 deletions

View File

@ -10,18 +10,18 @@ from image_prediction.predictor.predictor import Predictor
@pytest.fixture
def predictor(service_estimator):
return Predictor(service_estimator)
def predictor(estimator):
return Predictor(estimator)
@pytest.fixture
def service_estimator(estimator, classes):
service_estimator = Estimator(estimator, classes)
def estimator(estimator_adapter, classes):
service_estimator = Estimator(estimator_adapter, classes)
return service_estimator
@pytest.fixture
def estimator(estimator_type, keras_model, output_batch, monkeypatch):
def estimator_adapter(estimator_type, keras_model, output_batch, monkeypatch):
if estimator_type == "mock":
estimator = EstimatorAdapterMock(DummyEstimator())
elif estimator_type == "keras":

View File

@ -10,8 +10,8 @@ logger.setLevel(logging.DEBUG)
@pytest.mark.parametrize("estimator_type", ["mock", "keras"], scope="session")
@pytest.mark.parametrize("batch_size", [0, 1, 2, 16, 32, 64], scope="session")
def test_predict(service_estimator, input_batch, expected_predictions):
predictions = service_estimator.predict(input_batch)
def test_predict(estimator, input_batch, expected_predictions):
predictions = estimator.predict(input_batch)
assert predictions == expected_predictions