Matthias Bisping 9d58ae714f renaming
2022-03-27 17:55:01 +02:00

28 lines
774 B
Python

import logging
import pytest
from image_prediction.utils import get_logger
# Module-level logger obtained from the project's helper and set to DEBUG, so
# messages at DEBUG level and above are not filtered by this logger's own
# level while this test module is loaded. It is not referenced elsewhere in
# this module.
# NOTE(review): presumably get_logger() returns a shared logging.Logger, in
# which case this import-time level change also affects other test modules in
# the same pytest session — confirm.
logger = get_logger()
logger.setLevel(logging.DEBUG)
@pytest.mark.parametrize("estimator_type", ["mock", "keras"])
@pytest.mark.parametrize("batch_size", [0, 1, 2, 16, 32, 64])
def test_predict(classifier, input_batch, expected_predictions):
    """Classify ``input_batch`` and require the result to equal the expectation.

    The test runs once for every combination of estimator type ("mock",
    "keras") and batch size (0, 1, 2, 16, 32, 64).

    ``estimator_type`` and ``batch_size`` are not parameters of this function.
    pytest only allows that when fixtures with those names exist, so these
    parametrizations override fixtures defined elsewhere.
    NOTE(review): presumably they feed the ``classifier``, ``input_batch`` and
    ``expected_predictions`` fixtures (likely in a conftest) — confirm.
    """
    assert classifier.predict(input_batch) == expected_predictions
@pytest.mark.parametrize("batch_size", [0, 1, 2, 16, 32, 64])
def test_batch_format(input_batch):
    """Check the layout of the ``input_batch`` fixture for several batch sizes.

    Two properties are asserted, in this order:

    1. the last axis has size 3 (treated as the channel axis, i.e. channels-last);
    2. the batch has exactly four dimensions.

    ``batch_size`` is not a parameter here, so it overrides a fixture of that
    name. NOTE(review): presumably ``input_batch`` depends on it and has layout
    (batch, height, width, channels) — confirm against the fixture definition.
    """
    has_three_trailing_channels = input_batch.shape[-1] == 3
    assert has_three_trailing_channels

    assert input_batch.ndim == 4