testing laberl mappers for raising of excpetions when encountering unexpected input formats

This commit is contained in:
Matthias Bisping 2022-03-30 18:15:45 +02:00
parent ce3d33955e
commit 258c1ab02d
4 changed files with 11 additions and 5 deletions

View File

@ -3,9 +3,8 @@ import abc
class Preprocessor(abc.ABC): class Preprocessor(abc.ABC):
@staticmethod
@abc.abstractmethod @abc.abstractmethod
def preprocess(batch): def preprocess(self, batch):
raise NotImplementedError raise NotImplementedError
def __call__(self, batch): def __call__(self, batch):

View File

@ -6,10 +6,10 @@ ImageMetadataPair = namedtuple("ImageMetadataPair", ["image", "metadata"])
class ImageExtractor(abc.ABC): class ImageExtractor(abc.ABC):
@abc.abstractmethod @abc.abstractmethod
def extract(self, obj) -> Iterable[ImageMetadataPair]: def extract(self, obj) -> Iterable[ImageMetadataPair]:
"""Extracts images from an object""" raise NotImplementedError
pass
def __call__(self, obj): def __call__(self, obj):
return self.extract(obj) return self.extract(obj)

View File

@ -9,7 +9,7 @@ class IndexMapper(LabelMapper):
self.__labels = labels self.__labels = labels
def __validate_index_label_format(self, index_label: int) -> None: def __validate_index_label_format(self, index_label: int) -> None:
if not 0 <= index_label <= len(self.__labels): if not 0 <= index_label < len(self.__labels):
raise UnexpectedLabelFormat( raise UnexpectedLabelFormat(
f"Received index label '{index_label}' that has no associated string label." f"Received index label '{index_label}' that has no associated string label."
) )

View File

@ -1,3 +1,6 @@
import pytest
from image_prediction.exceptions import UnexpectedLabelFormat
from image_prediction.label_mapper.mappers.numeric import IndexMapper from image_prediction.label_mapper.mappers.numeric import IndexMapper
from image_prediction.label_mapper.mappers.probability import ProbabilityMapper from image_prediction.label_mapper.mappers.probability import ProbabilityMapper
@ -5,6 +8,8 @@ from image_prediction.label_mapper.mappers.probability import ProbabilityMapper
def test_index_label_mapper(batch_of_expected_numeric_labels, batch_of_expected_string_labels, classes): def test_index_label_mapper(batch_of_expected_numeric_labels, batch_of_expected_string_labels, classes):
mapper = IndexMapper(classes) mapper = IndexMapper(classes)
assert list(mapper(batch_of_expected_numeric_labels)) == batch_of_expected_string_labels assert list(mapper(batch_of_expected_numeric_labels)) == batch_of_expected_string_labels
with pytest.raises(UnexpectedLabelFormat):
list(mapper([len(classes)]))
def test_array_label_mapper( def test_array_label_mapper(
@ -12,3 +17,5 @@ def test_array_label_mapper(
): ):
mapper = ProbabilityMapper(classes) mapper = ProbabilityMapper(classes)
assert list(mapper(batch_of_expected_probability_arrays)) == batch_of_expected_label_to_probability_mappings assert list(mapper(batch_of_expected_probability_arrays)) == batch_of_expected_label_to_probability_mappings
with pytest.raises(UnexpectedLabelFormat):
list(mapper([[0] * len(classes) + [1]]))