testing laberl mappers for raising of excpetions when encountering unexpected input formats

This commit is contained in:
Matthias Bisping 2022-03-30 18:15:45 +02:00
parent ce3d33955e
commit 258c1ab02d
4 changed files with 11 additions and 5 deletions

View File

@ -3,9 +3,8 @@ import abc
class Preprocessor(abc.ABC):
@staticmethod
@abc.abstractmethod
def preprocess(batch):
def preprocess(self, batch):
raise NotImplementedError
def __call__(self, batch):

View File

@ -6,10 +6,10 @@ ImageMetadataPair = namedtuple("ImageMetadataPair", ["image", "metadata"])
class ImageExtractor(abc.ABC):
@abc.abstractmethod
def extract(self, obj) -> Iterable[ImageMetadataPair]:
"""Extracts images from an object"""
pass
raise NotImplementedError
def __call__(self, obj):
return self.extract(obj)

View File

@ -9,7 +9,7 @@ class IndexMapper(LabelMapper):
self.__labels = labels
def __validate_index_label_format(self, index_label: int) -> None:
if not 0 <= index_label <= len(self.__labels):
if not 0 <= index_label < len(self.__labels):
raise UnexpectedLabelFormat(
f"Received index label '{index_label}' that has no associated string label."
)

View File

@ -1,3 +1,6 @@
import pytest
from image_prediction.exceptions import UnexpectedLabelFormat
from image_prediction.label_mapper.mappers.numeric import IndexMapper
from image_prediction.label_mapper.mappers.probability import ProbabilityMapper
@ -5,6 +8,8 @@ from image_prediction.label_mapper.mappers.probability import ProbabilityMapper
def test_index_label_mapper(batch_of_expected_numeric_labels, batch_of_expected_string_labels, classes):
mapper = IndexMapper(classes)
assert list(mapper(batch_of_expected_numeric_labels)) == batch_of_expected_string_labels
with pytest.raises(UnexpectedLabelFormat):
list(mapper([len(classes)]))
def test_array_label_mapper(
@ -12,3 +17,5 @@ def test_array_label_mapper(
):
mapper = ProbabilityMapper(classes)
assert list(mapper(batch_of_expected_probability_arrays)) == batch_of_expected_label_to_probability_mappings
with pytest.raises(UnexpectedLabelFormat):
list(mapper([[0] * len(classes) + [1]]))