22 lines
931 B
Python
22 lines
931 B
Python
import pytest
|
|
|
|
from image_prediction.exceptions import UnexpectedLabelFormat
|
|
from image_prediction.label_mapper.mappers.numeric import IndexMapper
|
|
from image_prediction.label_mapper.mappers.probability import ProbabilityMapper
|
|
|
|
|
|
def test_index_label_mapper(batch_of_expected_numeric_labels, batch_of_expected_string_labels, classes):
    """Check IndexMapper on a valid batch and on an index equal to len(classes).

    The mapper built from ``classes`` must turn the batch of numeric labels
    into the expected batch of string labels, and must raise
    ``UnexpectedLabelFormat`` when given the index ``len(classes)``.
    """
    index_mapper = IndexMapper(classes)

    # The result is materialised with list() before comparing
    # (presumably the mapper returns a lazy iterable -- confirm against IndexMapper).
    mapped_labels = list(index_mapper(batch_of_expected_numeric_labels))
    assert mapped_labels == batch_of_expected_string_labels

    # NOTE(review): len(classes) is presumably one past the last valid index.
    invalid_index = len(classes)
    with pytest.raises(UnexpectedLabelFormat):
        # Consumption stays inside the context so a lazily raised error is caught.
        list(index_mapper([invalid_index]))
|
|
|
|
|
|
def test_array_label_mapper(
    batch_of_expected_probability_arrays, batch_of_expected_label_to_probability_mappings, classes
):
    """Check ProbabilityMapper on a valid batch and on an over-long array.

    The mapper built from ``classes`` must turn the batch of probability
    arrays into the expected label-to-probability mappings, and must raise
    ``UnexpectedLabelFormat`` for an array with ``len(classes) + 1`` entries.
    """
    probability_mapper = ProbabilityMapper(classes)

    # The result is materialised with list() before comparing
    # (presumably the mapper returns a lazy iterable -- confirm against ProbabilityMapper).
    mapped_batch = list(probability_mapper(batch_of_expected_probability_arrays))
    assert mapped_batch == batch_of_expected_label_to_probability_mappings

    # One entry more than there are classes: len(classes) zeros followed by a 1.
    oversized_array = [0] * len(classes) + [1]
    with pytest.raises(UnexpectedLabelFormat):
        # Consumption stays inside the context so a lazily raised error is caught.
        list(probability_mapper([oversized_array]))
|