"""Tests for the index- and probability-based label mappers."""

import pytest

from image_prediction.exceptions import UnexpectedLabelFormat
from image_prediction.label_mapper.mappers.numeric import IndexMapper
from image_prediction.label_mapper.mappers.probability import ProbabilityMapper

# NOTE(review): the fixtures used below (classes, batch_of_expected_*) are not
# defined in this module; presumably they come from a conftest.py — confirm.


def test_index_label_mapper(batch_of_expected_numeric_labels, batch_of_expected_string_labels, classes):
    """IndexMapper maps numeric labels to string labels and rejects an out-of-range label."""
    mapper = IndexMapper(classes)

    # The mapper's output is wrapped in list() before comparing; it may be a lazy iterable.
    assert list(mapper(batch_of_expected_numeric_labels)) == batch_of_expected_string_labels

    # len(classes) is one past the last valid index (assuming 0-based indices),
    # so mapping it must raise. list() forces evaluation in case the mapper is lazy.
    with pytest.raises(UnexpectedLabelFormat):
        list(mapper([len(classes)]))


def test_array_label_mapper(
    batch_of_expected_probability_arrays, batch_of_expected_label_to_probability_mappings, classes
):
    """ProbabilityMapper maps probability arrays to label->probability mappings and rejects bad lengths."""
    mapper = ProbabilityMapper(classes)

    assert list(mapper(batch_of_expected_probability_arrays)) == batch_of_expected_label_to_probability_mappings

    # An array with one more entry than there are classes does not match the
    # class list, so mapping it must raise. list() forces evaluation.
    with pytest.raises(UnexpectedLabelFormat):
        list(mapper([[0] * len(classes) + [1]]))