formatting
This commit is contained in:
parent
41f0cc8a41
commit
e8fb01b4b7
@ -1,8 +1,14 @@
|
|||||||
|
import logging
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from image_prediction.estimator.mock import EstimatorMock
|
from image_prediction.estimator.mock import EstimatorMock
|
||||||
from image_prediction.service_estimator.mock import ServiceEstimatorMock
|
from image_prediction.service_estimator.mock import ServiceEstimatorMock
|
||||||
|
from image_prediction.utils import get_logger
|
||||||
|
|
||||||
|
logger = get_logger()
|
||||||
|
logger.setLevel(logging.DEBUG)
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="session")
|
@pytest.fixture(scope="session")
|
||||||
@ -35,7 +41,12 @@ def service_estimator(model_type, estimator, classes):
|
|||||||
@pytest.mark.parametrize("model_type", ["mock"], scope="session")
|
@pytest.mark.parametrize("model_type", ["mock"], scope="session")
|
||||||
@pytest.mark.parametrize("batch_size", [0, 1, 2, 16, 32, 64], scope="session")
|
@pytest.mark.parametrize("batch_size", [0, 1, 2, 16, 32, 64], scope="session")
|
||||||
def test_predict(service_estimator, batches, classes):
|
def test_predict(service_estimator, batches, classes):
|
||||||
|
|
||||||
input_batch, output_batch = batches
|
input_batch, output_batch = batches
|
||||||
service_estimator.estimator.output_batch = output_batch
|
|
||||||
expected_predictions = map_labels(output_batch, classes)
|
expected_predictions = map_labels(output_batch, classes)
|
||||||
assert service_estimator.predict(input_batch) == expected_predictions
|
|
||||||
|
service_estimator.estimator.output_batch = output_batch
|
||||||
|
|
||||||
|
predictions = service_estimator.predict(input_batch)
|
||||||
|
|
||||||
|
assert predictions == expected_predictions
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user