Matthias Bisping ea298dacfa renaming
2022-03-26 19:27:37 +01:00

28 lines
745 B
Python

import logging
import pytest
from image_prediction.utils import get_logger
# Module-level logger obtained from the project's get_logger helper; its level is
# set to DEBUG here, at import time of this test module.
logger = get_logger()
logger.setLevel(logging.DEBUG)
@pytest.mark.parametrize("estimator_type", ["mock", "keras"], scope="session")
@pytest.mark.parametrize("batch_size", [0, 1, 2, 16, 32, 64], scope="session")
def test_predict(estimator, input_batch, expected_predictions):
    """Check that predicting on a batch yields exactly the expected predictions.

    The test runs once for every combination of ``estimator_type`` and
    ``batch_size``. Neither name is a parameter of this function, so they
    presumably feed the fixtures requested here (``estimator``,
    ``input_batch``, ``expected_predictions``) -- TODO confirm in conftest.
    """
    assert estimator.predict(input_batch) == expected_predictions
def test_batch_format(input_batch):
    """Check the layout of ``input_batch``: trailing axis of size 3, four axes in total.

    The checks run in a fixed order -- trailing-axis size first, then the
    number of axes -- so the first violated expectation is the one reported.
    """
    # Channels-last layout: the final axis is expected to hold 3 channels.
    trailing_axis_size = input_batch.shape[-1]
    assert trailing_axis_size == 3

    # The batch must be a fourth-order tensor, i.e. have exactly four axes.
    axis_count = input_batch.ndim
    assert axis_count == 4