20 lines
624 B
Python
20 lines
624 B
Python
import pytest
|
|
|
|
|
|
@pytest.mark.parametrize("estimator_type", ["mock", "keras", "redai"])
@pytest.mark.parametrize("label_format", ["index", "probability"])
def test_classifier(classifier, input_batch, expected_predictions_mapped):
    """Check that the classifier's predictions equal the expected mapped labels.

    ``estimator_type`` and ``label_format`` are not arguments of this test, so
    they are parametrized for the benefit of the fixtures it depends on
    (presumably ``classifier`` and/or ``expected_predictions_mapped`` — confirm
    in the conftest). The test runs once per combination of the two values.
    """
    # Materialise the classifier output so it compares equal to the expected
    # sequence regardless of which iterable type the classifier returns.
    assert list(classifier(input_batch)) == expected_predictions_mapped
|
|
|
|
|
|
def test_batch_format(input_batch):
    """Check that ``input_batch`` is a 4-D batch whose last axis has size 3.

    Only ``.ndim`` and ``.shape`` are read. ``input_batch`` is a fixture
    defined outside this file (presumably a NumPy-style array — confirm).
    """
    # Check the rank first: indexing ``shape[-1]`` on a 0-d array would raise
    # IndexError instead of producing a readable assertion failure.
    assert input_batch.ndim == 4, (
        f"expected a 4-D batch, got shape {input_batch.shape}"
    )
    # Channels-last layout: the last axis must hold exactly 3 channels.
    assert input_batch.shape[-1] == 3, (
        f"expected 3 channels on the last axis, got shape {input_batch.shape}"
    )
|