testing laberl mappers for raising of excpetions when encountering unexpected input formats
This commit is contained in:
parent
ce3d33955e
commit
258c1ab02d
@ -3,9 +3,8 @@ import abc
|
|||||||
|
|
||||||
class Preprocessor(abc.ABC):
|
class Preprocessor(abc.ABC):
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
def preprocess(batch):
|
def preprocess(self, batch):
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
def __call__(self, batch):
|
def __call__(self, batch):
|
||||||
|
|||||||
@ -6,10 +6,10 @@ ImageMetadataPair = namedtuple("ImageMetadataPair", ["image", "metadata"])
|
|||||||
|
|
||||||
|
|
||||||
class ImageExtractor(abc.ABC):
|
class ImageExtractor(abc.ABC):
|
||||||
|
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
def extract(self, obj) -> Iterable[ImageMetadataPair]:
|
def extract(self, obj) -> Iterable[ImageMetadataPair]:
|
||||||
"""Extracts images from an object"""
|
raise NotImplementedError
|
||||||
pass
|
|
||||||
|
|
||||||
def __call__(self, obj):
|
def __call__(self, obj):
|
||||||
return self.extract(obj)
|
return self.extract(obj)
|
||||||
|
|||||||
@ -9,7 +9,7 @@ class IndexMapper(LabelMapper):
|
|||||||
self.__labels = labels
|
self.__labels = labels
|
||||||
|
|
||||||
def __validate_index_label_format(self, index_label: int) -> None:
|
def __validate_index_label_format(self, index_label: int) -> None:
|
||||||
if not 0 <= index_label <= len(self.__labels):
|
if not 0 <= index_label < len(self.__labels):
|
||||||
raise UnexpectedLabelFormat(
|
raise UnexpectedLabelFormat(
|
||||||
f"Received index label '{index_label}' that has no associated string label."
|
f"Received index label '{index_label}' that has no associated string label."
|
||||||
)
|
)
|
||||||
|
|||||||
@ -1,3 +1,6 @@
|
|||||||
|
import pytest
|
||||||
|
|
||||||
|
from image_prediction.exceptions import UnexpectedLabelFormat
|
||||||
from image_prediction.label_mapper.mappers.numeric import IndexMapper
|
from image_prediction.label_mapper.mappers.numeric import IndexMapper
|
||||||
from image_prediction.label_mapper.mappers.probability import ProbabilityMapper
|
from image_prediction.label_mapper.mappers.probability import ProbabilityMapper
|
||||||
|
|
||||||
@ -5,6 +8,8 @@ from image_prediction.label_mapper.mappers.probability import ProbabilityMapper
|
|||||||
def test_index_label_mapper(batch_of_expected_numeric_labels, batch_of_expected_string_labels, classes):
|
def test_index_label_mapper(batch_of_expected_numeric_labels, batch_of_expected_string_labels, classes):
|
||||||
mapper = IndexMapper(classes)
|
mapper = IndexMapper(classes)
|
||||||
assert list(mapper(batch_of_expected_numeric_labels)) == batch_of_expected_string_labels
|
assert list(mapper(batch_of_expected_numeric_labels)) == batch_of_expected_string_labels
|
||||||
|
with pytest.raises(UnexpectedLabelFormat):
|
||||||
|
list(mapper([len(classes)]))
|
||||||
|
|
||||||
|
|
||||||
def test_array_label_mapper(
|
def test_array_label_mapper(
|
||||||
@ -12,3 +17,5 @@ def test_array_label_mapper(
|
|||||||
):
|
):
|
||||||
mapper = ProbabilityMapper(classes)
|
mapper = ProbabilityMapper(classes)
|
||||||
assert list(mapper(batch_of_expected_probability_arrays)) == batch_of_expected_label_to_probability_mappings
|
assert list(mapper(batch_of_expected_probability_arrays)) == batch_of_expected_label_to_probability_mappings
|
||||||
|
with pytest.raises(UnexpectedLabelFormat):
|
||||||
|
list(mapper([[0] * len(classes) + [1]]))
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user