Julius Unverfehrt 265c61df1a RED-5107 add handling for 'broken' images: broken image parts are
replaced by blank images in the stitching process, and completely broken
images are also replaced by blank images, which are passed through and
are classified as 'other' with all_pased == False. This should be
changed in the future by introducing a new key to the response
indicating that the image is not valid.
2022-08-30 12:30:23 +02:00

43 lines
1.7 KiB
Python

from enum import Enum
from operator import itemgetter
from typing import Mapping, Iterable, Union
import numpy as np
from funcy import rcompose, rpartial
from image_prediction.exceptions import UnexpectedLabelFormat
from image_prediction.label_mapper.mapper import LabelMapper
class ProbabilityMapperKeys(Enum):
    """Keys of the per-image result dict built by ``ProbabilityMapper``.

    The enum members themselves (not their string values) are used as the dict keys.
    """

    # Most likely label for the image.
    LABEL = "label"
    # Mapping label -> rounded probability, ordered from most to least likely.
    PROBABILITIES = "probabilities"
class ProbabilityMapper(LabelMapper):
    """Turn per-image probability vectors into a most likely label plus a sorted probability table.

    Each vector is paired position by position with ``labels`` (in the order that iterating ``labels``
    yields them). Each probability is rounded to 4 decimals, and the pairs are ordered from most to
    least likely.
    """

    def __init__(self, labels: Mapping[int, str]):
        """Store the labels and build the probability rounder.

        Args:
            labels: Labels paired positionally with each probability vector. Whatever iterating this
                object yields becomes the label in the output; for a dict that is its keys, not its
                values. NOTE(review): the hint says ``Mapping[int, str]``. Confirm whether callers pass
                a sequence of label names, or whether ``labels.values()`` was intended.
        """
        self.__labels = labels
        # String conversion in the middle due to floating point precision issues.
        # See: https://stackoverflow.com/questions/56820/round-doesnt-seem-to-be-rounding-properly
        self.__rounder = rcompose(rpartial(round, 4), str, float)

    def __validate_array_label_format(self, probabilities: np.ndarray) -> None:
        """Ensure there is exactly one probability per label.

        Raises:
            UnexpectedLabelFormat: If ``len(probabilities) != len(labels)``.
        """
        n_probabilities, n_labels = len(probabilities), len(self.__labels)
        if n_probabilities != n_labels:
            # The check fires in both directions, so report the actual direction of the mismatch.
            relation = "fewer" if n_probabilities < n_labels else "more"
            raise UnexpectedLabelFormat(
                f"Received {relation} probabilities ({n_probabilities}) than labels were passed ({n_labels})."
            )

    def __map_array(self, probabilities: Union[np.ndarray, None]) -> Union[dict, None]:
        """Map a single probability vector.

        Args:
            probabilities: One probability per label, or ``None`` when there is no prediction.

        Returns:
            ``None`` when ``probabilities`` is ``None`` or an empty non-array sequence. Otherwise, a dict
            mapping ``ProbabilityMapperKeys.LABEL`` to the most likely label and
            ``ProbabilityMapperKeys.PROBABILITIES`` to {label: rounded probability}, most likely first.

        Raises:
            UnexpectedLabelFormat: If the vector length differs from the number of labels, or if there
                are no labels and no probabilities at all (so no most likely label exists).
        """
        # An ndarray's truth value is ambiguous for more than one element, so only non-arrays get the
        # falsy check; this lets missing predictions pass through as None.
        if not isinstance(probabilities, np.ndarray) and not probabilities:
            return None

        self.__validate_array_label_format(probabilities)

        rounded = map(self.__rounder, probabilities)
        # sorted() is stable even with reverse=True, so equal probabilities keep the label order.
        cls2prob = dict(sorted(zip(self.__labels, rounded), key=itemgetter(1), reverse=True))

        if not cls2prob:
            # Only reachable with zero labels and an empty array; fail clearly instead of an IndexError.
            raise UnexpectedLabelFormat(
                "Received no probabilities and no labels; cannot determine the most likely label."
            )

        # Dicts preserve insertion order, so the first key is the highest-probability label.
        most_likely = next(iter(cls2prob))
        return {ProbabilityMapperKeys.LABEL: most_likely, ProbabilityMapperKeys.PROBABILITIES: cls2prob}

    def map_labels(self, probabilities: Iterable[np.ndarray]) -> Iterable[Union[dict, None]]:
        """Lazily map each probability vector to a result dict (or ``None`` for a missing vector).

        Args:
            probabilities: One probability vector (or ``None``) per image.

        Returns:
            A lazy ``map`` object; validation errors surface only when it is consumed.
        """
        return map(self.__map_array, probabilities)