diff --git a/image_prediction/estimator/preprocessor/preprocessor.py b/image_prediction/estimator/preprocessor/preprocessor.py index 3959b51..6f39564 100644 --- a/image_prediction/estimator/preprocessor/preprocessor.py +++ b/image_prediction/estimator/preprocessor/preprocessor.py @@ -3,9 +3,8 @@ import abc class Preprocessor(abc.ABC): - @staticmethod @abc.abstractmethod - def preprocess(batch): + def preprocess(self, batch): raise NotImplementedError def __call__(self, batch): diff --git a/image_prediction/image_extractor/extractor.py b/image_prediction/image_extractor/extractor.py index 66bfc3b..fc318a8 100644 --- a/image_prediction/image_extractor/extractor.py +++ b/image_prediction/image_extractor/extractor.py @@ -6,10 +6,10 @@ ImageMetadataPair = namedtuple("ImageMetadataPair", ["image", "metadata"]) class ImageExtractor(abc.ABC): + @abc.abstractmethod def extract(self, obj) -> Iterable[ImageMetadataPair]: - """Extracts images from an object""" - pass + raise NotImplementedError def __call__(self, obj): return self.extract(obj) diff --git a/image_prediction/label_mapper/mappers/numeric.py b/image_prediction/label_mapper/mappers/numeric.py index 00e9f45..f08de75 100644 --- a/image_prediction/label_mapper/mappers/numeric.py +++ b/image_prediction/label_mapper/mappers/numeric.py @@ -9,7 +9,7 @@ class IndexMapper(LabelMapper): self.__labels = labels def __validate_index_label_format(self, index_label: int) -> None: - if not 0 <= index_label <= len(self.__labels): + if not 0 <= index_label < len(self.__labels): raise UnexpectedLabelFormat( f"Received index label '{index_label}' that has no associated string label." ) diff --git a/test/unit_tests/label_mapper_test.py b/test/unit_tests/label_mapper_test.py index 83de74a..23da392 100644 --- a/test/unit_tests/label_mapper_test.py +++ b/test/unit_tests/label_mapper_test.py @@ -1,3 +1,6 @@ +import pytest + +from image_prediction.exceptions import UnexpectedLabelFormat from image_prediction.label_mapper.mappers.numeric import IndexMapper from image_prediction.label_mapper.mappers.probability import ProbabilityMapper @@ -5,6 +8,8 @@ from image_prediction.label_mapper.mappers.probability import ProbabilityMapper def test_index_label_mapper(batch_of_expected_numeric_labels, batch_of_expected_string_labels, classes): mapper = IndexMapper(classes) assert list(mapper(batch_of_expected_numeric_labels)) == batch_of_expected_string_labels + with pytest.raises(UnexpectedLabelFormat): + list(mapper([len(classes)])) def test_array_label_mapper( @@ -12,3 +17,5 @@ def test_array_label_mapper( ): mapper = ProbabilityMapper(classes) assert list(mapper(batch_of_expected_probability_arrays)) == batch_of_expected_label_to_probability_mappings + with pytest.raises(UnexpectedLabelFormat): + list(mapper([[0] * len(classes) + [1]]))