renaming
This commit is contained in:
parent
373c619b0c
commit
ea298dacfa
@ -10,18 +10,18 @@ from image_prediction.predictor.predictor import Predictor
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def predictor(service_estimator):
|
||||
return Predictor(service_estimator)
|
||||
def predictor(estimator):
|
||||
return Predictor(estimator)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def service_estimator(estimator, classes):
|
||||
service_estimator = Estimator(estimator, classes)
|
||||
def estimator(estimator_adapter, classes):
|
||||
service_estimator = Estimator(estimator_adapter, classes)
|
||||
return service_estimator
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def estimator(estimator_type, keras_model, output_batch, monkeypatch):
|
||||
def estimator_adapter(estimator_type, keras_model, output_batch, monkeypatch):
|
||||
if estimator_type == "mock":
|
||||
estimator = EstimatorAdapterMock(DummyEstimator())
|
||||
elif estimator_type == "keras":
|
||||
|
||||
@ -10,8 +10,8 @@ logger.setLevel(logging.DEBUG)
|
||||
|
||||
@pytest.mark.parametrize("estimator_type", ["mock", "keras"], scope="session")
|
||||
@pytest.mark.parametrize("batch_size", [0, 1, 2, 16, 32, 64], scope="session")
|
||||
def test_predict(service_estimator, input_batch, expected_predictions):
|
||||
predictions = service_estimator.predict(input_batch)
|
||||
def test_predict(estimator, input_batch, expected_predictions):
|
||||
predictions = estimator.predict(input_batch)
|
||||
assert predictions == expected_predictions
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user