added label mapper
This commit is contained in:
parent
99d8e921db
commit
8c7e3e29f5
@ -5,7 +5,7 @@ import numpy as np
|
||||
from PIL.Image import Image
|
||||
|
||||
from image_prediction.estimator.adapter.adapter import EstimatorAdapter
|
||||
from image_prediction.exceptions import UnexpectedPredictionFormat
|
||||
from image_prediction.exceptions import UnexpectedLabelFormat
|
||||
from image_prediction.utils import get_logger
|
||||
|
||||
logger = get_logger()
|
||||
@ -25,13 +25,13 @@ class Classifier:
|
||||
|
||||
def __validate_array_prediction_format(self, prediction):
|
||||
if not len(prediction) == len(self._classes):
|
||||
raise UnexpectedPredictionFormat(
|
||||
raise UnexpectedLabelFormat(
|
||||
f"Received fewer probabilities ({len(prediction)}) than classes were specified ({len(self._classes)})."
|
||||
)
|
||||
|
||||
def __validate_int_prediction_format(self, prediction):
|
||||
if not 0 <= prediction <= len(self._classes):
|
||||
raise UnexpectedPredictionFormat(
|
||||
raise UnexpectedLabelFormat(
|
||||
f"Received class index '{prediction}' as prediction that has no associated class label."
|
||||
)
|
||||
|
||||
|
||||
@ -14,7 +14,7 @@ class UnknownDatabaseType(ValueError):
|
||||
pass
|
||||
|
||||
|
||||
class UnexpectedPredictionFormat(ValueError):
|
||||
class UnexpectedLabelFormat(ValueError):
|
||||
pass
|
||||
|
||||
|
||||
|
||||
0
image_prediction/label_mapper/__init__.py
Normal file
0
image_prediction/label_mapper/__init__.py
Normal file
11
image_prediction/label_mapper/mapper.py
Normal file
11
image_prediction/label_mapper/mapper.py
Normal file
@ -0,0 +1,11 @@
|
||||
import abc
|
||||
|
||||
|
||||
class LabelMapper(abc.ABC):
|
||||
|
||||
@abc.abstractmethod
|
||||
def map_labels(self, items):
|
||||
raise NotImplementedError
|
||||
|
||||
def __call__(self, items):
|
||||
return self.map_labels(items)
|
||||
0
image_prediction/label_mapper/mappers/__init__.py
Normal file
0
image_prediction/label_mapper/mappers/__init__.py
Normal file
22
image_prediction/label_mapper/mappers/numeric.py
Normal file
22
image_prediction/label_mapper/mappers/numeric.py
Normal file
@ -0,0 +1,22 @@
|
||||
from typing import Mapping, Iterable
|
||||
|
||||
from image_prediction.exceptions import UnexpectedLabelFormat
|
||||
from image_prediction.label_mapper.mapper import LabelMapper
|
||||
|
||||
|
||||
class IndexLabelMapper(LabelMapper):
|
||||
def __init__(self, labels: Mapping[int, str]):
|
||||
self.__labels = labels
|
||||
|
||||
def __validate_int_prediction_format(self, index_label: int) -> None:
|
||||
if not 0 <= index_label <= len(self.__labels):
|
||||
raise UnexpectedLabelFormat(
|
||||
f"Received index label '{index_label}' that has no associated string label."
|
||||
)
|
||||
|
||||
def __map_label(self, index_label: int) -> str:
|
||||
self.__validate_int_prediction_format(index_label)
|
||||
return self.__labels[index_label]
|
||||
|
||||
def map_labels(self, index_labels: Iterable[int]) -> Iterable[str]:
|
||||
return map(self.__map_label, index_labels)
|
||||
@ -134,8 +134,8 @@ def expected_batch_string_labels(expected_batch_numeric_labels, classes):
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def expected_batch_numeric_labels(input_batch, classes):
|
||||
return random.choices(range(len(classes)), k=len(input_batch))
|
||||
def expected_batch_numeric_labels(batch_size, classes):
|
||||
return random.choices(range(len(classes)), k=batch_size)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
||||
6
test/unit_tests/label_mapper_test.py
Normal file
6
test/unit_tests/label_mapper_test.py
Normal file
@ -0,0 +1,6 @@
|
||||
from image_prediction.label_mapper.mappers.numeric import IndexLabelMapper
|
||||
|
||||
|
||||
def test_index_label_mapper(expected_batch_numeric_labels, expected_batch_string_labels, classes):
|
||||
mapper = IndexLabelMapper(classes)
|
||||
assert list(mapper(expected_batch_numeric_labels)) == expected_batch_string_labels
|
||||
Loading…
x
Reference in New Issue
Block a user