"""Tests for the classifier and the shape of the input batch fed to it."""

import pytest

# Channel count the batch is expected to carry in its trailing axis.
_EXPECTED_CHANNELS = 3
# Number of dimensions (tensor order) the batch is expected to have.
_EXPECTED_NDIM = 4


def _channels_are_last(batch):
    """Return True if the trailing axis of *batch* has the expected channel count."""
    return batch.shape[-1] == _EXPECTED_CHANNELS


def _is_fourth_order_tensor(batch):
    """Return True if *batch* has exactly ``_EXPECTED_NDIM`` dimensions."""
    return batch.ndim == _EXPECTED_NDIM


# The two stacked decorators run this test over the cross product
# 3 estimator types x 2 label formats = 6 cases.
# NOTE(review): neither argname appears in the test signature, so pytest only
# accepts them if fixtures in this test's closure request them. Presumably
# `classifier` and/or `expected_predictions_mapped` are defined elsewhere
# (e.g. conftest.py) and take `estimator_type` / `label_format`. Confirm.
@pytest.mark.parametrize("estimator_type", ["mock", "keras", "redai"])
@pytest.mark.parametrize("label_format", ["index", "probability"])
def test_classifier(classifier, input_batch, expected_predictions_mapped):
    """Check that the classifier's predictions on *input_batch* match expectations.

    The classifier's output is materialized with ``list`` so that an iterator
    or generator result can be compared against the expected sequence.
    """
    predictions = list(classifier(input_batch))
    assert predictions == expected_predictions_mapped


def test_batch_format(input_batch):
    """Check that *input_batch* is a 4-D, channels-last array.

    Only the ``shape`` and ``ndim`` attributes of the fixture are used.
    Presumably it is a NumPy-like array. TODO confirm against the fixture.
    """
    assert _channels_are_last(input_batch), (
        f"expected {_EXPECTED_CHANNELS} channels in the last axis, "
        f"got shape {input_batch.shape}"
    )
    assert _is_fourth_order_tensor(input_batch), (
        f"expected a {_EXPECTED_NDIM}-D batch, got ndim={input_batch.ndim}"
    )