28 lines
745 B
Python
28 lines
745 B
Python
import logging
|
|
|
|
import pytest
|
|
|
|
from image_prediction.utils import get_logger
|
|
|
|
# Logger for this test module, as returned by image_prediction.utils.get_logger().
logger = get_logger()
|
|
# Set the logger's level to DEBUG (assumes get_logger() returns a standard
# logging.Logger, so DEBUG and above are emitted -- TODO confirm).
logger.setLevel(logging.DEBUG)
|
|
|
|
|
|
@pytest.mark.parametrize("estimator_type", ["mock", "keras"], scope="session")
@pytest.mark.parametrize("batch_size", [0, 1, 2, 16, 32, 64], scope="session")
def test_predict(estimator, input_batch, expected_predictions):
    """Check that the estimator's output for the batch equals the expectation.

    The test runs once per combination of ``estimator_type`` and
    ``batch_size``. Neither name appears in this signature; presumably they
    are consumed by the ``estimator`` / ``input_batch`` /
    ``expected_predictions`` fixtures defined elsewhere -- verify against the
    project's conftest.
    """
    actual = estimator.predict(input_batch)
    assert actual == expected_predictions
|
|
|
|
|
|
def test_batch_format(input_batch):
|
|
|
|
def channels_are_last(input_batch):
|
|
return input_batch.shape[-1] == 3
|
|
|
|
def is_fourth_order_tensor(input_batch):
|
|
return input_batch.ndim == 4
|
|
|
|
assert channels_are_last(input_batch)
|
|
assert is_fourth_order_tensor(input_batch)
|