added array label mapper
This commit is contained in:
parent
7f37f841dd
commit
8bccec277f
@ -4,18 +4,18 @@ from image_prediction.exceptions import UnexpectedLabelFormat
|
|||||||
from image_prediction.label_mapper.mapper import LabelMapper
|
from image_prediction.label_mapper.mapper import LabelMapper
|
||||||
|
|
||||||
|
|
||||||
class IndexLabelMapper(LabelMapper):
|
class IndexMapper(LabelMapper):
|
||||||
def __init__(self, labels: Mapping[int, str]):
|
def __init__(self, labels: Mapping[int, str]):
|
||||||
self.__labels = labels
|
self.__labels = labels
|
||||||
|
|
||||||
def __validate_int_prediction_format(self, index_label: int) -> None:
|
def __validate_index_label_format(self, index_label: int) -> None:
|
||||||
if not 0 <= index_label <= len(self.__labels):
|
if not 0 <= index_label <= len(self.__labels):
|
||||||
raise UnexpectedLabelFormat(
|
raise UnexpectedLabelFormat(
|
||||||
f"Received index label '{index_label}' that has no associated string label."
|
f"Received index label '{index_label}' that has no associated string label."
|
||||||
)
|
)
|
||||||
|
|
||||||
def __map_label(self, index_label: int) -> str:
|
def __map_label(self, index_label: int) -> str:
|
||||||
self.__validate_int_prediction_format(index_label)
|
self.__validate_index_label_format(index_label)
|
||||||
return self.__labels[index_label]
|
return self.__labels[index_label]
|
||||||
|
|
||||||
def map_labels(self, index_labels: Iterable[int]) -> Iterable[str]:
|
def map_labels(self, index_labels: Iterable[int]) -> Iterable[str]:
|
||||||
|
|||||||
27
image_prediction/label_mapper/mappers/probability.py
Normal file
27
image_prediction/label_mapper/mappers/probability.py
Normal file
@ -0,0 +1,27 @@
|
|||||||
|
from operator import itemgetter
|
||||||
|
from typing import Mapping, Iterable
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from image_prediction.exceptions import UnexpectedLabelFormat
|
||||||
|
from image_prediction.label_mapper.mapper import LabelMapper
|
||||||
|
|
||||||
|
|
||||||
|
class ProbabilityMapper(LabelMapper):
|
||||||
|
def __init__(self, labels: Mapping[int, str]):
|
||||||
|
self.__labels = labels
|
||||||
|
|
||||||
|
def __validate_array_label_format(self, probabilities: np.ndarray) -> None:
|
||||||
|
if not len(probabilities) == len(self.__labels):
|
||||||
|
raise UnexpectedLabelFormat(
|
||||||
|
f"Received fewer probabilities ({len(probabilities)}) than labels were passed ({len(self.__labels)})."
|
||||||
|
)
|
||||||
|
|
||||||
|
def __map_array(self, probabilities: np.ndarray) -> dict:
|
||||||
|
self.__validate_array_label_format(probabilities)
|
||||||
|
cls2prob = dict(sorted(zip(self.__labels, probabilities), key=itemgetter(1), reverse=True))
|
||||||
|
most_likely = [*cls2prob][0]
|
||||||
|
return {"label": most_likely, "probabilities": cls2prob}
|
||||||
|
|
||||||
|
def map_labels(self, probabilities: Iterable[np.ndarray]) -> Iterable[dict]:
|
||||||
|
return map(self.__map_array, probabilities)
|
||||||
@ -1,11 +1,14 @@
|
|||||||
from image_prediction.label_mapper.mappers.numeric import IndexLabelMapper
|
from image_prediction.label_mapper.mappers.numeric import IndexMapper
|
||||||
|
from image_prediction.label_mapper.mappers.probability import ProbabilityMapper
|
||||||
|
|
||||||
|
|
||||||
def test_index_label_mapper(batch_of_expected_numeric_labels, batch_of_expected_string_labels, classes):
|
def test_index_label_mapper(batch_of_expected_numeric_labels, batch_of_expected_string_labels, classes):
|
||||||
mapper = IndexLabelMapper(classes)
|
mapper = IndexMapper(classes)
|
||||||
assert list(mapper(batch_of_expected_numeric_labels)) == batch_of_expected_string_labels
|
assert list(mapper(batch_of_expected_numeric_labels)) == batch_of_expected_string_labels
|
||||||
|
|
||||||
|
|
||||||
# def test_array_label_mapper(expected_batch_array, batch_of_expected_label_to_probability_mappings, classes):
|
def test_array_label_mapper(
|
||||||
# mapper = ProbabilityMapper(classes)
|
batch_of_expected_probability_arrays, batch_of_expected_label_to_probability_mappings, classes
|
||||||
# assert list(mapper(expected_batch_numeric_labels)) == expected_batch_string_labels
|
):
|
||||||
|
mapper = ProbabilityMapper(classes)
|
||||||
|
assert list(mapper(batch_of_expected_probability_arrays)) == batch_of_expected_label_to_probability_mappings
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user