diff --git a/image_prediction/classifier/classifier.py b/image_prediction/classifier/classifier.py index a9164c5..546c8f2 100644 --- a/image_prediction/classifier/classifier.py +++ b/image_prediction/classifier/classifier.py @@ -5,7 +5,7 @@ import numpy as np from PIL.Image import Image from image_prediction.estimator.adapter.adapter import EstimatorAdapter -from image_prediction.exceptions import UnexpectedPredictionFormat +from image_prediction.exceptions import UnexpectedLabelFormat from image_prediction.utils import get_logger logger = get_logger() @@ -25,13 +25,13 @@ class Classifier: def __validate_array_prediction_format(self, prediction): if not len(prediction) == len(self._classes): - raise UnexpectedPredictionFormat( + raise UnexpectedLabelFormat( f"Received fewer probabilities ({len(prediction)}) than classes were specified ({len(self._classes)})." ) def __validate_int_prediction_format(self, prediction): if not 0 <= prediction <= len(self._classes): - raise UnexpectedPredictionFormat( + raise UnexpectedLabelFormat( f"Received class index '{prediction}' as prediction that has no associated class label." ) diff --git a/image_prediction/exceptions.py b/image_prediction/exceptions.py index 730286d..3b82070 100644 --- a/image_prediction/exceptions.py +++ b/image_prediction/exceptions.py @@ -14,7 +14,7 @@ class UnknownDatabaseType(ValueError): pass -class UnexpectedPredictionFormat(ValueError): +class UnexpectedLabelFormat(ValueError): pass diff --git a/image_prediction/label_mapper/__init__.py b/image_prediction/label_mapper/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/image_prediction/label_mapper/mapper.py b/image_prediction/label_mapper/mapper.py new file mode 100644 index 0000000..663de94 --- /dev/null +++ b/image_prediction/label_mapper/mapper.py @@ -0,0 +1,11 @@ +import abc + + +class LabelMapper(abc.ABC): + + @abc.abstractmethod + def map_labels(self, items): + raise NotImplementedError + + def __call__(self, items): + return self.map_labels(items) diff --git a/image_prediction/label_mapper/mappers/__init__.py b/image_prediction/label_mapper/mappers/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/image_prediction/label_mapper/mappers/numeric.py b/image_prediction/label_mapper/mappers/numeric.py new file mode 100644 index 0000000..d166186 --- /dev/null +++ b/image_prediction/label_mapper/mappers/numeric.py @@ -0,0 +1,22 @@ +from typing import Mapping, Iterable + +from image_prediction.exceptions import UnexpectedLabelFormat +from image_prediction.label_mapper.mapper import LabelMapper + + +class IndexLabelMapper(LabelMapper): + def __init__(self, labels: Mapping[int, str]): + self.__labels = labels + + def __validate_int_prediction_format(self, index_label: int) -> None: + if not 0 <= index_label <= len(self.__labels): + raise UnexpectedLabelFormat( + f"Received index label '{index_label}' that has no associated string label." + ) + + def __map_label(self, index_label: int) -> str: + self.__validate_int_prediction_format(index_label) + return self.__labels[index_label] + + def map_labels(self, index_labels: Iterable[int]) -> Iterable[str]: + return map(self.__map_label, index_labels) diff --git a/test/unit_tests/conftest.py b/test/unit_tests/conftest.py index 8500b8e..85d560d 100644 --- a/test/unit_tests/conftest.py +++ b/test/unit_tests/conftest.py @@ -134,8 +134,8 @@ def expected_batch_string_labels(expected_batch_numeric_labels, classes): @pytest.fixture -def expected_batch_numeric_labels(input_batch, classes): - return random.choices(range(len(classes)), k=len(input_batch)) +def expected_batch_numeric_labels(batch_size, classes): + return random.choices(range(len(classes)), k=batch_size) @pytest.fixture diff --git a/test/unit_tests/label_mapper_test.py b/test/unit_tests/label_mapper_test.py new file mode 100644 index 0000000..06e48b9 --- /dev/null +++ b/test/unit_tests/label_mapper_test.py @@ -0,0 +1,6 @@ +from image_prediction.label_mapper.mappers.numeric import IndexLabelMapper + + +def test_index_label_mapper(expected_batch_numeric_labels, expected_batch_string_labels, classes): + mapper = IndexLabelMapper(classes) + assert list(mapper(expected_batch_numeric_labels)) == expected_batch_string_labels