refactoring

This commit is contained in:
Matthias Bisping 2022-03-25 12:24:23 +01:00
parent e8fb01b4b7
commit 9c9070e8bf

View File

@ -12,15 +12,25 @@ logger.setLevel(logging.DEBUG)
@pytest.fixture(scope="session")
def estimator():
return EstimatorMock()
def estimator(output_batch):
estimator = EstimatorMock()
estimator.output_batch = output_batch
return estimator
@pytest.fixture(scope="session")
def batches(batch_size, classes):
input_batch = np.random.normal(size=(batch_size, 10, 15))
output_batch = np.random.randint(low=0, high=len(classes), size=batch_size)
return input_batch, output_batch
def input_batch(batch_size, classes):
return np.random.normal(size=(batch_size, 10, 15))
@pytest.fixture(scope="session")
def output_batch(batch_size, classes):
return np.random.randint(low=0, high=len(classes), size=batch_size)
@pytest.fixture(scope="session")
def expected_predictions(output_batch, classes):
return map_labels(output_batch, classes)
@pytest.fixture(scope="session")
@ -40,13 +50,6 @@ def service_estimator(model_type, estimator, classes):
@pytest.mark.parametrize("model_type", ["mock"], scope="session")
@pytest.mark.parametrize("batch_size", [0, 1, 2, 16, 32, 64], scope="session")
def test_predict(service_estimator, batches, classes):
input_batch, output_batch = batches
expected_predictions = map_labels(output_batch, classes)
service_estimator.estimator.output_batch = output_batch
def test_predict(service_estimator, input_batch, expected_predictions):
predictions = service_estimator.predict(input_batch)
assert predictions == expected_predictions