tuning prediction format handling
This commit is contained in:
parent
8b15ac6df4
commit
81ab9a5f53
@ -1,3 +1,4 @@
|
|||||||
|
from operator import itemgetter
|
||||||
from typing import Mapping, List, Union, Tuple
|
from typing import Mapping, List, Union, Tuple
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@ -22,27 +23,31 @@ class Classifier:
|
|||||||
self.__estimator_adapter = estimator_adapter
|
self.__estimator_adapter = estimator_adapter
|
||||||
self._classes = classes
|
self._classes = classes
|
||||||
|
|
||||||
def __validate_dict_prediction_format(self, prediction):
|
|
||||||
if not max(prediction.keys) <= len(self._classes):
|
|
||||||
raise UnexpectedPredictionFormat(f"Received prediction in an unexpected format: {prediction}")
|
|
||||||
|
|
||||||
def __validate_array_prediction_format(self, prediction):
|
def __validate_array_prediction_format(self, prediction):
|
||||||
if not len(prediction) == len(self._classes):
|
if not len(prediction) == len(self._classes):
|
||||||
raise UnexpectedPredictionFormat(
|
raise UnexpectedPredictionFormat(
|
||||||
f"Received fewer probabilities ({len(prediction)}) than classes were specified ({len(self._classes)}."
|
f"Received fewer probabilities ({len(prediction)}) than classes were specified ({len(self._classes)}."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def __validate_int_prediction_format(self, prediction):
|
||||||
|
if not 0 <= prediction <= len(self._classes):
|
||||||
|
raise UnexpectedPredictionFormat(
|
||||||
|
f"Received class index '{prediction}' as prediction that has no associated class label."
|
||||||
|
)
|
||||||
|
|
||||||
|
def __format_array_prediction_format(self, prediction):
|
||||||
|
cls2prob = dict(sorted(zip(self._classes, prediction), key=itemgetter(1)))
|
||||||
|
most_likely = [*cls2prob][0]
|
||||||
|
return {"label": most_likely, "probabilities": cls2prob}
|
||||||
|
|
||||||
def __format_prediction(self, prediction):
|
def __format_prediction(self, prediction):
|
||||||
if isinstance(prediction, int):
|
if isinstance(prediction, int):
|
||||||
|
self.__validate_int_prediction_format(prediction)
|
||||||
return self._classes[prediction]
|
return self._classes[prediction]
|
||||||
|
|
||||||
elif isinstance(prediction, dict):
|
|
||||||
self.__validate_dict_prediction_format(prediction)
|
|
||||||
return {self._classes[cls_idx] for cls_idx, prob in prediction.items()}
|
|
||||||
|
|
||||||
elif isinstance(prediction, np.ndarray):
|
elif isinstance(prediction, np.ndarray):
|
||||||
self.__validate_array_prediction_format(prediction)
|
self.__validate_array_prediction_format(prediction)
|
||||||
return dict(zip(self._classes, prediction))
|
return self.__format_array_prediction_format(prediction)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
return prediction
|
return prediction
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user