From 15c0b730349dbc1b6a168b021892f6e81c81ecdf Mon Sep 17 00:00:00 2001 From: Matthias Bisping Date: Tue, 29 Mar 2022 23:41:43 +0200 Subject: [PATCH] support for different prediction formats --- image_prediction/classifier/classifier.py | 22 +++++++++++++++++++--- image_prediction/exceptions.py | 6 +++++- 2 files changed, 24 insertions(+), 4 deletions(-) diff --git a/image_prediction/classifier/classifier.py b/image_prediction/classifier/classifier.py index 7d2a02b..fbcfdf1 100644 --- a/image_prediction/classifier/classifier.py +++ b/image_prediction/classifier/classifier.py @@ -4,6 +4,7 @@ import numpy as np from PIL.Image import Image from image_prediction.estimator.adapter.adapter import EstimatorAdapter +from image_prediction.exceptions import UnexpectedPredictionFormat from image_prediction.utils import get_logger logger = get_logger() @@ -15,19 +16,34 @@ class Classifier: an EstimatorAdapter must be implemented. Args: - estimator_adapter: adapter for a given estimator backend; expected to be a classifier that returns numeric - labels as predictions + estimator_adapter: adapter for a given estimator backend; expected to be a classifier that returns mappings + from numeric labels to probabilities as predictions or numeric labels classes: mapping from a numerical label to a human-readable label for classes """ self.__estimator_adapter = estimator_adapter self._classes = classes + def __validate_prediction_format(self, prediction): + if not max(prediction.keys) <= len(self._classes): + raise UnexpectedPredictionFormat(f"Received prediction in an unexpected format: {prediction}") + + def __format_prediction(self, prediction): + if isinstance(prediction, int): + return self._classes[prediction] + + elif isinstance(prediction, dict): + self.__validate_prediction_format(prediction) + return {self._classes[cls_idx] for cls_idx, prob in prediction.items()} + + else: + return prediction + def predict(self, batch: Union[np.array, Tuple[Image]]) -> List[str]: if not isinstance(batch, tuple) and batch.shape[0] == 0: return [] - return [self._classes[numeric_label] for numeric_label in self.__estimator_adapter.predict(batch)] + return list(map(self.__format_prediction, self.__estimator_adapter.predict(batch))) def __call__(self, batch: np.array) -> List[str]: return self.predict(batch) diff --git a/image_prediction/exceptions.py b/image_prediction/exceptions.py index 9cc0f5d..03da3ec 100644 --- a/image_prediction/exceptions.py +++ b/image_prediction/exceptions.py @@ -14,5 +14,9 @@ class UnknownDatabaseType(ValueError): pass -class IncorrectInstantiation(RuntimeError): +class UnexpectedPredictionFormat(ValueError): pass + + +class IncorrectInstantiation(RuntimeError): + pass \ No newline at end of file