33 lines
1.4 KiB
Python
33 lines
1.4 KiB
Python
from functools import partial
|
|
from operator import itemgetter
|
|
from typing import Mapping, Iterable
|
|
|
|
import numpy as np
|
|
from funcy import rcompose
|
|
|
|
from image_prediction.exceptions import UnexpectedLabelFormat
|
|
from image_prediction.label_mapper.mapper import LabelMapper
|
|
|
|
|
|
class ProbabilityMapper(LabelMapper):
    """Turn per-class probability vectors into ``{"label", "probabilities"}`` dicts.

    Each probability vector is paired positionally with the items obtained by
    iterating ``labels``, each probability is rounded to 4 decimal places, and the
    pairs are ordered by descending probability. The first (most probable) entry
    is reported as ``"label"``.
    """

    def __init__(self, labels: Mapping[int, str]):
        """
        Args:
            labels: Class labels. Iterating this object (in order) supplies the keys of
                the output ``"probabilities"`` dict, matched to positions in each
                probability vector. Its length must equal each vector's length.
        """
        # NOTE(review): iterating a Mapping yields its *keys*, so with a
        # ``Mapping[int, str]`` the output keys are the int keys, not the str names,
        # and pairing follows the mapping's iteration order. If names are intended,
        # pair with ``labels.values()`` (or ``labels[i]``) — confirm against callers
        # before changing, since it alters the returned dicts.
        self.__labels = labels
        # String conversion in the middle due to floating point precision issues.
        # See: https://stackoverflow.com/questions/56820/round-doesnt-seem-to-be-rounding-properly
        self.__rounder = rcompose(lambda d: round(d, 4), str, float)

    def __validate_array_label_format(self, probabilities: np.ndarray) -> None:
        """Ensure there is exactly one probability per label.

        Raises:
            UnexpectedLabelFormat: If ``len(probabilities) != len(labels)``.
        """
        n_probabilities, n_labels = len(probabilities), len(self.__labels)
        if n_probabilities != n_labels:
            # Report the actual direction of the mismatch; too many is as invalid as too few.
            relation = "fewer" if n_probabilities < n_labels else "more"
            raise UnexpectedLabelFormat(
                f"Received {relation} probabilities ({n_probabilities}) than labels were passed ({n_labels})."
            )

    def __map_array(self, probabilities: np.ndarray) -> dict:
        """Map one probability vector to its most likely label and all rounded probabilities.

        Returns:
            ``{"label": most_likely, "probabilities": {label: rounded_prob, ...}}`` where
            ``"probabilities"`` is ordered by descending probability. Ties keep the
            label iteration order, because ``sorted`` is stable.

        Raises:
            UnexpectedLabelFormat: If the vector length does not match the number of
                labels, or if there are no labels/probabilities at all.
        """
        self.__validate_array_label_format(probabilities)
        rounded = map(self.__rounder, probabilities)
        cls2prob = dict(sorted(zip(self.__labels, rounded), key=itemgetter(1), reverse=True))
        if not cls2prob:
            # Must raise explicitly: a StopIteration from next() below would be swallowed
            # by the lazy ``map`` in ``map_labels`` and silently end the iteration.
            raise UnexpectedLabelFormat("Received no probabilities and no labels; cannot determine a label.")
        # The dict is sorted by descending probability, so its first key is the most likely label.
        most_likely = next(iter(cls2prob))
        return {"label": most_likely, "probabilities": cls2prob}

    def map_labels(self, probabilities: Iterable[np.ndarray]) -> Iterable[dict]:
        """Map each probability vector in ``probabilities`` to a label dict.

        The result is a lazy ``map`` iterator, so validation errors are raised only
        when the corresponding element is consumed.
        """
        return map(self.__map_array, probabilities)
|