2022-03-30 15:54:18 +02:00

28 lines
1.1 KiB
Python

from operator import itemgetter
from typing import Mapping, Iterable
import numpy as np
from image_prediction.exceptions import UnexpectedLabelFormat
from image_prediction.label_mapper.mapper import LabelMapper
class ProbabilityMapper(LabelMapper):
    """Map raw probability arrays to a ranked label/probability dictionary.

    Each probability array is paired element-wise with the configured labels,
    sorted by descending probability, and returned together with the most
    likely label.
    """

    def __init__(self, labels: Mapping[int, str]):
        """Store the labels that probability arrays are matched against.

        Args:
            labels: Collection of class labels. It is iterated directly when
                pairing with probabilities, so its iteration order must match
                the order of the entries in each probability array.
        """
        # NOTE(review): iterating a Mapping yields its keys, so with a
        # ``Mapping[int, str]`` the reported labels are the integer keys, not
        # the string names -- confirm against callers whether that is intended.
        self.__labels = labels

    def __validate_array_label_format(self, probabilities: np.ndarray) -> None:
        """Ensure there is exactly one probability per configured label.

        Args:
            probabilities: Probability array for a single prediction.

        Raises:
            UnexpectedLabelFormat: If the number of probabilities differs from
                the number of labels (either fewer or more).
        """
        if len(probabilities) != len(self.__labels):
            # The check rejects any mismatch, so the message must not claim
            # that there were specifically "fewer" probabilities.
            raise UnexpectedLabelFormat(
                f"Received a different number of probabilities ({len(probabilities)}) "
                f"than labels were passed ({len(self.__labels)})."
            )

    def __map_array(self, probabilities: np.ndarray) -> dict:
        """Build the ranked result for a single probability array.

        Args:
            probabilities: Probability array for a single prediction.

        Returns:
            A dict with the keys ``"label"`` (the label with the highest
            probability) and ``"probabilities"`` (a dict of label to
            probability, ordered from most to least likely).

        Raises:
            UnexpectedLabelFormat: If the array length does not match the
                number of labels.
        """
        self.__validate_array_label_format(probabilities)

        # Sort (label, probability) pairs by probability, highest first.
        ranked = sorted(zip(self.__labels, probabilities), key=itemgetter(1), reverse=True)
        cls2prob = dict(ranked)

        # The first ranked pair holds the most likely label. Indexing (rather
        # than ``next(iter(...))``) keeps raising IndexError on empty input
        # instead of a StopIteration that ``map`` would silently swallow.
        most_likely = ranked[0][0]

        return {"label": most_likely, "probabilities": cls2prob}

    def map_labels(self, probabilities: Iterable[np.ndarray]) -> Iterable[dict]:
        """Lazily map each probability array to its ranked result.

        Args:
            probabilities: Iterable of probability arrays, one per prediction.

        Returns:
            A lazy iterator of result dicts (see ``__map_array``). Because the
            mapping is lazy, validation errors are raised while the returned
            iterator is consumed, not when this method is called.
        """
        return map(self.__map_array, probabilities)