diff --git a/image_prediction/classifier/classifier.py b/image_prediction/classifier/classifier.py index 6998bda..fc1d50b 100644 --- a/image_prediction/classifier/classifier.py +++ b/image_prediction/classifier/classifier.py @@ -1,3 +1,4 @@ +from operator import itemgetter from typing import Mapping, List, Union, Tuple import numpy as np @@ -22,27 +23,31 @@ class Classifier: self.__estimator_adapter = estimator_adapter self._classes = classes - def __validate_dict_prediction_format(self, prediction): - if not max(prediction.keys) <= len(self._classes): - raise UnexpectedPredictionFormat(f"Received prediction in an unexpected format: {prediction}") - def __validate_array_prediction_format(self, prediction): if not len(prediction) == len(self._classes): raise UnexpectedPredictionFormat( f"Received fewer probabilities ({len(prediction)}) than classes were specified ({len(self._classes)}." ) + def __validate_int_prediction_format(self, prediction): + if not 0 <= prediction <= len(self._classes): + raise UnexpectedPredictionFormat( + f"Received class index '{prediction}' as prediction that has no associated class label." + ) + + def __format_array_prediction_format(self, prediction): + cls2prob = dict(sorted(zip(self._classes, prediction), key=itemgetter(1))) + most_likely = [*cls2prob][0] + return {"label": most_likely, "probabilities": cls2prob} + def __format_prediction(self, prediction): if isinstance(prediction, int): + self.__validate_int_prediction_format(prediction) return self._classes[prediction] - elif isinstance(prediction, dict): - self.__validate_dict_prediction_format(prediction) - return {self._classes[cls_idx] for cls_idx, prob in prediction.items()} - elif isinstance(prediction, np.ndarray): self.__validate_array_prediction_format(prediction) - return dict(zip(self._classes, prediction)) + return self.__format_array_prediction_format(prediction) else: return prediction