from enum import Enum
from operator import itemgetter
from typing import Mapping, Iterable

import numpy as np
from funcy import rcompose, rpartial

from image_prediction.exceptions import UnexpectedLabelFormat
from image_prediction.label_mapper.mapper import LabelMapper


class ProbabilityMapperKeys(Enum):
    """Keys of the dictionaries produced by ``ProbabilityMapper.map_labels``."""

    LABEL = "label"
    PROBABILITIES = "probabilities"


class ProbabilityMapper(LabelMapper):
    """Turn per-class probability vectors into labelled, rounded, ranked results.

    Each probability vector is paired element-wise with the configured labels,
    each probability is rounded to 4 decimal places, and the pairs are ordered
    by descending probability. The first (highest-probability) label is
    reported as the most likely one.
    """

    def __init__(self, labels: Mapping[int, str]):
        """Store the labels and build the probability rounder.

        Args:
            labels: Class labels, in the same order as the entries of every
                probability vector later passed to ``map_labels``. Only their
                iteration order and ``len`` are used.

        NOTE(review): iterating a ``Mapping`` yields its *keys*. If a dict such
        as ``{0: "a", 1: "b"}`` is passed, the reported labels are the integer
        keys, not the names. Confirm which form callers actually pass.
        """
        self.__labels = labels
        # String conversion in the middle due to floating point precision issues.
        # See: https://stackoverflow.com/questions/56820/round-doesnt-seem-to-be-rounding-properly
        self.__rounder = rcompose(rpartial(round, 4), str, float)

    def __validate_array_label_format(self, probabilities: np.ndarray) -> None:
        """Ensure there is exactly one probability per label.

        Raises:
            UnexpectedLabelFormat: if the number of probabilities differs from
                the number of labels (in either direction).
        """
        if len(probabilities) != len(self.__labels):
            raise UnexpectedLabelFormat(
                f"Received {len(probabilities)} probabilities, but {len(self.__labels)} labels were passed; "
                "the counts must match."
            )

    def __map_array(self, probabilities: np.ndarray) -> dict:
        """Map one probability vector to its most likely label and ranked probabilities.

        Returns:
            A dict with two entries. ``ProbabilityMapperKeys.LABEL`` maps to the
            label with the highest rounded probability.
            ``ProbabilityMapperKeys.PROBABILITIES`` maps to a
            ``{label: rounded_probability}`` dict ordered by descending
            probability.

        Raises:
            UnexpectedLabelFormat: if the vector length does not match the
                number of labels, or if the vector is empty.
        """
        self.__validate_array_label_format(probabilities)
        if len(probabilities) == 0:
            # Without this check, an empty label set would fail later with a bare IndexError.
            raise UnexpectedLabelFormat("Received no probabilities; cannot determine the most likely label.")

        rounded = map(self.__rounder, probabilities)
        # sorted() is stable even with reverse=True, so tied probabilities keep label order.
        ranked = sorted(zip(self.__labels, rounded), key=itemgetter(1), reverse=True)
        most_likely = ranked[0][0]

        return {ProbabilityMapperKeys.LABEL: most_likely, ProbabilityMapperKeys.PROBABILITIES: dict(ranked)}

    def map_labels(self, probabilities: Iterable[np.ndarray]) -> Iterable[dict]:
        """Lazily map each probability vector to a labelled result dict.

        Returns a ``map`` iterator, so validation errors surface only while
        the result is being iterated.
        """
        return map(self.__map_array, probabilities)