From 8bccec277fb17b383dab447f7604aee30c49a980 Mon Sep 17 00:00:00 2001 From: Matthias Bisping Date: Wed, 30 Mar 2022 15:54:18 +0200 Subject: [PATCH] added array label mapper --- .../label_mapper/mappers/numeric.py | 6 ++--- .../label_mapper/mappers/probability.py | 27 +++++++++++++++++++ test/unit_tests/label_mapper_test.py | 13 +++++---- 3 files changed, 38 insertions(+), 8 deletions(-) create mode 100644 image_prediction/label_mapper/mappers/probability.py diff --git a/image_prediction/label_mapper/mappers/numeric.py b/image_prediction/label_mapper/mappers/numeric.py index d166186..00e9f45 100644 --- a/image_prediction/label_mapper/mappers/numeric.py +++ b/image_prediction/label_mapper/mappers/numeric.py @@ -4,18 +4,18 @@ from image_prediction.exceptions import UnexpectedLabelFormat from image_prediction.label_mapper.mapper import LabelMapper -class IndexLabelMapper(LabelMapper): +class IndexMapper(LabelMapper): def __init__(self, labels: Mapping[int, str]): self.__labels = labels - def __validate_int_prediction_format(self, index_label: int) -> None: + def __validate_index_label_format(self, index_label: int) -> None: if not 0 <= index_label <= len(self.__labels): raise UnexpectedLabelFormat( f"Received index label '{index_label}' that has no associated string label." ) def __map_label(self, index_label: int) -> str: - self.__validate_int_prediction_format(index_label) + self.__validate_index_label_format(index_label) return self.__labels[index_label] def map_labels(self, index_labels: Iterable[int]) -> Iterable[str]: diff --git a/image_prediction/label_mapper/mappers/probability.py b/image_prediction/label_mapper/mappers/probability.py new file mode 100644 index 0000000..9808d5f --- /dev/null +++ b/image_prediction/label_mapper/mappers/probability.py @@ -0,0 +1,27 @@ +from operator import itemgetter +from typing import Mapping, Iterable + +import numpy as np + +from image_prediction.exceptions import UnexpectedLabelFormat +from image_prediction.label_mapper.mapper import LabelMapper + + +class ProbabilityMapper(LabelMapper): + def __init__(self, labels: Mapping[int, str]): + self.__labels = labels + + def __validate_array_label_format(self, probabilities: np.ndarray) -> None: + if not len(probabilities) == len(self.__labels): + raise UnexpectedLabelFormat( + f"Received fewer probabilities ({len(probabilities)}) than labels were passed ({len(self.__labels)})." + ) + + def __map_array(self, probabilities: np.ndarray) -> dict: + self.__validate_array_label_format(probabilities) + cls2prob = dict(sorted(zip(self.__labels, probabilities), key=itemgetter(1), reverse=True)) + most_likely = [*cls2prob][0] + return {"label": most_likely, "probabilities": cls2prob} + + def map_labels(self, probabilities: Iterable[np.ndarray]) -> Iterable[dict]: + return map(self.__map_array, probabilities) diff --git a/test/unit_tests/label_mapper_test.py b/test/unit_tests/label_mapper_test.py index c420bbe..83de74a 100644 --- a/test/unit_tests/label_mapper_test.py +++ b/test/unit_tests/label_mapper_test.py @@ -1,11 +1,14 @@ -from image_prediction.label_mapper.mappers.numeric import IndexLabelMapper +from image_prediction.label_mapper.mappers.numeric import IndexMapper +from image_prediction.label_mapper.mappers.probability import ProbabilityMapper def test_index_label_mapper(batch_of_expected_numeric_labels, batch_of_expected_string_labels, classes): - mapper = IndexLabelMapper(classes) + mapper = IndexMapper(classes) assert list(mapper(batch_of_expected_numeric_labels)) == batch_of_expected_string_labels -# def test_array_label_mapper(expected_batch_array, batch_of_expected_label_to_probability_mappings, classes): -# mapper = ProbabilityMapper(classes) -# assert list(mapper(expected_batch_numeric_labels)) == expected_batch_string_labels +def test_array_label_mapper( + batch_of_expected_probability_arrays, batch_of_expected_label_to_probability_mappings, classes +): + mapper = ProbabilityMapper(classes) + assert list(mapper(batch_of_expected_probability_arrays)) == batch_of_expected_label_to_probability_mappings