added array label mapper

This commit is contained in:
Matthias Bisping 2022-03-30 15:54:18 +02:00
parent 7f37f841dd
commit 8bccec277f
3 changed files with 38 additions and 8 deletions

View File

@ -4,18 +4,18 @@ from image_prediction.exceptions import UnexpectedLabelFormat
from image_prediction.label_mapper.mapper import LabelMapper
class IndexLabelMapper(LabelMapper):
class IndexMapper(LabelMapper):
def __init__(self, labels: Mapping[int, str]):
self.__labels = labels
def __validate_int_prediction_format(self, index_label: int) -> None:
def __validate_index_label_format(self, index_label: int) -> None:
if not 0 <= index_label <= len(self.__labels):
raise UnexpectedLabelFormat(
f"Received index label '{index_label}' that has no associated string label."
)
def __map_label(self, index_label: int) -> str:
self.__validate_int_prediction_format(index_label)
self.__validate_index_label_format(index_label)
return self.__labels[index_label]
def map_labels(self, index_labels: Iterable[int]) -> Iterable[str]:

View File

@ -0,0 +1,27 @@
from operator import itemgetter
from typing import Mapping, Iterable
import numpy as np
from image_prediction.exceptions import UnexpectedLabelFormat
from image_prediction.label_mapper.mapper import LabelMapper
class ProbabilityMapper(LabelMapper):
def __init__(self, labels: Mapping[int, str]):
self.__labels = labels
def __validate_array_label_format(self, probabilities: np.ndarray) -> None:
if not len(probabilities) == len(self.__labels):
raise UnexpectedLabelFormat(
f"Received fewer probabilities ({len(probabilities)}) than labels were passed ({len(self.__labels)})."
)
def __map_array(self, probabilities: np.ndarray) -> dict:
self.__validate_array_label_format(probabilities)
cls2prob = dict(sorted(zip(self.__labels, probabilities), key=itemgetter(1), reverse=True))
most_likely = [*cls2prob][0]
return {"label": most_likely, "probabilities": cls2prob}
def map_labels(self, probabilities: Iterable[np.ndarray]) -> Iterable[dict]:
return map(self.__map_array, probabilities)

View File

@ -1,11 +1,14 @@
from image_prediction.label_mapper.mappers.numeric import IndexLabelMapper
from image_prediction.label_mapper.mappers.numeric import IndexMapper
from image_prediction.label_mapper.mappers.probability import ProbabilityMapper
def test_index_label_mapper(batch_of_expected_numeric_labels, batch_of_expected_string_labels, classes):
mapper = IndexLabelMapper(classes)
mapper = IndexMapper(classes)
assert list(mapper(batch_of_expected_numeric_labels)) == batch_of_expected_string_labels
# def test_array_label_mapper(expected_batch_array, batch_of_expected_label_to_probability_mappings, classes):
# mapper = ProbabilityMapper(classes)
# assert list(mapper(expected_batch_numeric_labels)) == expected_batch_string_labels
def test_array_label_mapper(
batch_of_expected_probability_arrays, batch_of_expected_label_to_probability_mappings, classes
):
mapper = ProbabilityMapper(classes)
assert list(mapper(batch_of_expected_probability_arrays)) == batch_of_expected_label_to_probability_mappings