added label mapper

This commit is contained in:
Matthias Bisping 2022-03-30 14:17:58 +02:00
parent 99d8e921db
commit 8c7e3e29f5
8 changed files with 45 additions and 6 deletions

View File

@ -5,7 +5,7 @@ import numpy as np
from PIL.Image import Image
from image_prediction.estimator.adapter.adapter import EstimatorAdapter
from image_prediction.exceptions import UnexpectedPredictionFormat
from image_prediction.exceptions import UnexpectedLabelFormat
from image_prediction.utils import get_logger
logger = get_logger()
@ -25,13 +25,13 @@ class Classifier:
def __validate_array_prediction_format(self, prediction):
if not len(prediction) == len(self._classes):
raise UnexpectedPredictionFormat(
raise UnexpectedLabelFormat(
f"Received fewer probabilities ({len(prediction)}) than classes were specified ({len(self._classes)})."
)
def __validate_int_prediction_format(self, prediction):
if not 0 <= prediction <= len(self._classes):
raise UnexpectedPredictionFormat(
raise UnexpectedLabelFormat(
f"Received class index '{prediction}' as prediction that has no associated class label."
)

View File

@ -14,7 +14,7 @@ class UnknownDatabaseType(ValueError):
pass
class UnexpectedPredictionFormat(ValueError):
class UnexpectedLabelFormat(ValueError):
pass

View File

@ -0,0 +1,11 @@
import abc
class LabelMapper(abc.ABC):
@abc.abstractmethod
def map_labels(self, items):
raise NotImplementedError
def __call__(self, items):
return self.map_labels(items)

View File

@ -0,0 +1,22 @@
from typing import Mapping, Iterable
from image_prediction.exceptions import UnexpectedLabelFormat
from image_prediction.label_mapper.mapper import LabelMapper
class IndexLabelMapper(LabelMapper):
def __init__(self, labels: Mapping[int, str]):
self.__labels = labels
def __validate_int_prediction_format(self, index_label: int) -> None:
if not 0 <= index_label <= len(self.__labels):
raise UnexpectedLabelFormat(
f"Received index label '{index_label}' that has no associated string label."
)
def __map_label(self, index_label: int) -> str:
self.__validate_int_prediction_format(index_label)
return self.__labels[index_label]
def map_labels(self, index_labels: Iterable[int]) -> Iterable[str]:
return map(self.__map_label, index_labels)

View File

@ -134,8 +134,8 @@ def expected_batch_string_labels(expected_batch_numeric_labels, classes):
@pytest.fixture
def expected_batch_numeric_labels(input_batch, classes):
return random.choices(range(len(classes)), k=len(input_batch))
def expected_batch_numeric_labels(batch_size, classes):
return random.choices(range(len(classes)), k=batch_size)
@pytest.fixture

View File

@ -0,0 +1,6 @@
from image_prediction.label_mapper.mappers.numeric import IndexLabelMapper
def test_index_label_mapper(expected_batch_numeric_labels, expected_batch_string_labels, classes):
mapper = IndexLabelMapper(classes)
assert list(mapper(expected_batch_numeric_labels)) == expected_batch_string_labels