# 20 lines
# 618 B
# Python

import pytest
@pytest.mark.parametrize("estimator_type", ["mock", "keras", "redai"])
@pytest.mark.parametrize("label_format", ["index", "probability"])
def test_classifier(classifier, input_batch, expected_predictions_mapped):
    """Check that calling the classifier on the batch yields the expected predictions.

    The test is parametrized over ``estimator_type`` and ``label_format``.
    Neither name appears in the signature, so they are presumably consumed
    by the fixtures requested here (defined outside this file -- confirm).
    """
    # NOTE(review): the comparison relies on ``==`` producing a single truth
    # value; confirm the prediction type supports that (e.g. not an
    # element-wise array comparison).
    assert classifier(input_batch) == expected_predictions_mapped
def test_batch_format(input_batch):
def channels_are_last(input_batch):
return input_batch.shape[-1] == 3
def is_fourth_order_tensor(input_batch):
return input_batch.ndim == 4
assert channels_are_last(input_batch)
assert is_fourth_order_tensor(input_batch)