"""Tests for classifier predictions and for the format of input image batches."""

import logging

import pytest

from image_prediction.utils import get_logger

logger = get_logger()
logger.setLevel(logging.DEBUG)

# Batch sizes shared by every batch-dependent test below, so the
# parametrizations cannot drift apart. 0 exercises the empty-batch case.
BATCH_SIZES = [0, 1, 2, 16, 32, 64]


def _channels_are_last(batch):
    """Return True if the last axis of *batch* has length 3."""
    return batch.shape[-1] == 3


def _is_fourth_order_tensor(batch):
    """Return True if *batch* has exactly four dimensions."""
    return batch.ndim == 4


@pytest.mark.parametrize("estimator_type", ["mock", "keras"])
@pytest.mark.parametrize("batch_size", BATCH_SIZES)
def test_predict(classifier, input_batch, expected_predictions):
    """Check that the classifier's predictions equal the expected predictions.

    ``estimator_type`` and ``batch_size`` are not arguments of this function.
    Presumably they are consumed by the ``classifier`` / ``input_batch`` /
    ``expected_predictions`` fixtures defined elsewhere (e.g. conftest) -- confirm.
    """
    predictions = classifier.predict(input_batch)
    assert predictions == expected_predictions


@pytest.mark.parametrize("batch_size", BATCH_SIZES)
def test_batch_format(input_batch):
    """Check that an input batch is 4-D with 3 entries along its last axis.

    ``batch_size`` is presumably consumed by the ``input_batch`` fixture
    defined elsewhere -- confirm.
    """
    assert _channels_are_last(input_batch), (
        f"expected last axis of length 3, got shape {input_batch.shape}"
    )
    assert _is_fourth_order_tensor(input_batch), (
        f"expected a 4-D batch, got ndim={input_batch.ndim}"
    )