testing laberl mappers for raising of excpetions when encountering unexpected input formats
This commit is contained in:
parent
ce3d33955e
commit
258c1ab02d
@ -3,9 +3,8 @@ import abc
|
||||
|
||||
class Preprocessor(abc.ABC):
|
||||
|
||||
@staticmethod
|
||||
@abc.abstractmethod
|
||||
def preprocess(batch):
|
||||
def preprocess(self, batch):
|
||||
raise NotImplementedError
|
||||
|
||||
def __call__(self, batch):
|
||||
|
||||
@ -6,10 +6,10 @@ ImageMetadataPair = namedtuple("ImageMetadataPair", ["image", "metadata"])
|
||||
|
||||
|
||||
class ImageExtractor(abc.ABC):
|
||||
|
||||
@abc.abstractmethod
|
||||
def extract(self, obj) -> Iterable[ImageMetadataPair]:
|
||||
"""Extracts images from an object"""
|
||||
pass
|
||||
raise NotImplementedError
|
||||
|
||||
def __call__(self, obj):
|
||||
return self.extract(obj)
|
||||
|
||||
@ -9,7 +9,7 @@ class IndexMapper(LabelMapper):
|
||||
self.__labels = labels
|
||||
|
||||
def __validate_index_label_format(self, index_label: int) -> None:
|
||||
if not 0 <= index_label <= len(self.__labels):
|
||||
if not 0 <= index_label < len(self.__labels):
|
||||
raise UnexpectedLabelFormat(
|
||||
f"Received index label '{index_label}' that has no associated string label."
|
||||
)
|
||||
|
||||
@ -1,3 +1,6 @@
|
||||
import pytest
|
||||
|
||||
from image_prediction.exceptions import UnexpectedLabelFormat
|
||||
from image_prediction.label_mapper.mappers.numeric import IndexMapper
|
||||
from image_prediction.label_mapper.mappers.probability import ProbabilityMapper
|
||||
|
||||
@ -5,6 +8,8 @@ from image_prediction.label_mapper.mappers.probability import ProbabilityMapper
|
||||
def test_index_label_mapper(batch_of_expected_numeric_labels, batch_of_expected_string_labels, classes):
|
||||
mapper = IndexMapper(classes)
|
||||
assert list(mapper(batch_of_expected_numeric_labels)) == batch_of_expected_string_labels
|
||||
with pytest.raises(UnexpectedLabelFormat):
|
||||
list(mapper([len(classes)]))
|
||||
|
||||
|
||||
def test_array_label_mapper(
|
||||
@ -12,3 +17,5 @@ def test_array_label_mapper(
|
||||
):
|
||||
mapper = ProbabilityMapper(classes)
|
||||
assert list(mapper(batch_of_expected_probability_arrays)) == batch_of_expected_label_to_probability_mappings
|
||||
with pytest.raises(UnexpectedLabelFormat):
|
||||
list(mapper([[0] * len(classes) + [1]]))
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user