refactoring
This commit is contained in:
parent
e8fb01b4b7
commit
9c9070e8bf
@ -12,15 +12,25 @@ logger.setLevel(logging.DEBUG)
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def estimator():
|
||||
return EstimatorMock()
|
||||
def estimator(output_batch):
|
||||
estimator = EstimatorMock()
|
||||
estimator.output_batch = output_batch
|
||||
return estimator
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def batches(batch_size, classes):
|
||||
input_batch = np.random.normal(size=(batch_size, 10, 15))
|
||||
output_batch = np.random.randint(low=0, high=len(classes), size=batch_size)
|
||||
return input_batch, output_batch
|
||||
def input_batch(batch_size, classes):
|
||||
return np.random.normal(size=(batch_size, 10, 15))
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def output_batch(batch_size, classes):
|
||||
return np.random.randint(low=0, high=len(classes), size=batch_size)
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def expected_predictions(output_batch, classes):
|
||||
return map_labels(output_batch, classes)
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
@ -40,13 +50,6 @@ def service_estimator(model_type, estimator, classes):
|
||||
|
||||
@pytest.mark.parametrize("model_type", ["mock"], scope="session")
|
||||
@pytest.mark.parametrize("batch_size", [0, 1, 2, 16, 32, 64], scope="session")
|
||||
def test_predict(service_estimator, batches, classes):
|
||||
|
||||
input_batch, output_batch = batches
|
||||
expected_predictions = map_labels(output_batch, classes)
|
||||
|
||||
service_estimator.estimator.output_batch = output_batch
|
||||
|
||||
def test_predict(service_estimator, input_batch, expected_predictions):
|
||||
predictions = service_estimator.predict(input_batch)
|
||||
|
||||
assert predictions == expected_predictions
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user