From e9489287bdad498f9e9f57ee934367d2b8636e7b Mon Sep 17 00:00:00 2001 From: Matthias Bisping Date: Tue, 29 Mar 2022 23:56:22 +0200 Subject: [PATCH] support for array prediction format --- image_prediction/classifier/classifier.py | 14 ++++++++++++-- image_prediction/estimator/adapter/adapter.py | 5 +---- image_prediction/exceptions.py | 2 +- .../extractor_classifier/extractor_classifier.py | 2 +- image_prediction/redai_adapter/model.py | 11 +++++++---- test/unit_tests/conftest.py | 3 +++ test/unit_tests/extractor_classifier_test.py | 2 +- 7 files changed, 26 insertions(+), 13 deletions(-) diff --git a/image_prediction/classifier/classifier.py b/image_prediction/classifier/classifier.py index fbcfdf1..e027292 100644 --- a/image_prediction/classifier/classifier.py +++ b/image_prediction/classifier/classifier.py @@ -23,18 +23,28 @@ class Classifier: self.__estimator_adapter = estimator_adapter self._classes = classes - def __validate_prediction_format(self, prediction): + def __validate_dict_prediction_format(self, prediction): if not max(prediction.keys) <= len(self._classes): raise UnexpectedPredictionFormat(f"Received prediction in an unexpected format: {prediction}") + def __validate_array_prediction_format(self, prediction): + if not len(prediction) == len(self._classes): + raise UnexpectedPredictionFormat( + f"Received fewer probabilities ({len(prediction)}) than classes were specified ({len(self._classes)}." + ) + def __format_prediction(self, prediction): if isinstance(prediction, int): return self._classes[prediction] elif isinstance(prediction, dict): - self.__validate_prediction_format(prediction) + self.__validate_dict_prediction_format(prediction) return {self._classes[cls_idx] for cls_idx, prob in prediction.items()} + elif isinstance(prediction, np.ndarray): + self.__validate_array_prediction_format(prediction) + return dict(zip(self._classes, prediction)) + else: return prediction diff --git a/image_prediction/estimator/adapter/adapter.py b/image_prediction/estimator/adapter/adapter.py index 5e98bda..692178e 100644 --- a/image_prediction/estimator/adapter/adapter.py +++ b/image_prediction/estimator/adapter/adapter.py @@ -3,7 +3,4 @@ class EstimatorAdapter: self.estimator = estimator def predict(self, batch): - return self.estimator.predict(batch) - - def predict_proba(self, batch): - return self.estimator.predict_proba(batch) + return self.estimator(batch) diff --git a/image_prediction/exceptions.py b/image_prediction/exceptions.py index 03da3ec..730286d 100644 --- a/image_prediction/exceptions.py +++ b/image_prediction/exceptions.py @@ -19,4 +19,4 @@ class UnexpectedPredictionFormat(ValueError): class IncorrectInstantiation(RuntimeError): - pass \ No newline at end of file + pass diff --git a/image_prediction/extractor_classifier/extractor_classifier.py b/image_prediction/extractor_classifier/extractor_classifier.py index bb8198e..43e7788 100644 --- a/image_prediction/extractor_classifier/extractor_classifier.py +++ b/image_prediction/extractor_classifier/extractor_classifier.py @@ -21,7 +21,7 @@ class ExtractorClassifier: return [] predictions = self.classifier(images) - responses = ({"label": lbl, **mdt} for lbl, mdt in zip(predictions, metadata)) + responses = ({"prediction": prd, **mdt} for prd, mdt in zip(predictions, metadata)) return responses def __call__(self, obj) -> Iterable[ImageMetadataPair]: diff --git a/image_prediction/redai_adapter/model.py b/image_prediction/redai_adapter/model.py index d646f24..4ae7b79 100644 --- a/image_prediction/redai_adapter/model.py +++ b/image_prediction/redai_adapter/model.py @@ -6,11 +6,14 @@ class PredictionModelHandle: def __init__(self, model_handle): self.__model_handle = model_handle + self.__predict = rcompose(self.__model_handle.prep_images, self.__model_handle.model.predict) + self.__predict_proba = rcompose(self.__model_handle.prep_images, self.__model_handle.model.predict_proba) def predict(self, *args, **kwargs): - predict = rcompose(self.__model_handle.prep_images, self.__model_handle.model.predict) - return predict(*args, **kwargs) + return self.__predict(*args, **kwargs) def predict_proba(self, *args, **kwargs): - predict = rcompose(self.__model_handle.prep_images, self.__model_handle.model.predict_proba) - return predict(*args, **kwargs) \ No newline at end of file + return self.__predict_proba(*args, **kwargs) + + def __call__(self, *args, **kwargs): + return self.predict_proba(*args, **kwargs) diff --git a/test/unit_tests/conftest.py b/test/unit_tests/conftest.py index 0f65d24..64a8585 100644 --- a/test/unit_tests/conftest.py +++ b/test/unit_tests/conftest.py @@ -52,6 +52,9 @@ class EstimatorMock: def predict(batch): return [None for _ in batch] + def __call__(self, batch): + return self.predict(batch) + @pytest.fixture def estimator_adapter(estimator_type, keras_model, output_batch_generator, monkeypatch): diff --git a/test/unit_tests/extractor_classifier_test.py b/test/unit_tests/extractor_classifier_test.py index 9ce37d0..e42692f 100644 --- a/test/unit_tests/extractor_classifier_test.py +++ b/test/unit_tests/extractor_classifier_test.py @@ -11,5 +11,5 @@ from image_prediction.extractor_classifier.extractor_classifier import Extractor def test_extractor_classifier(image_extractor, image_classifier, images, expected_predictions): extractor_classifier = ExtractorClassifier(image_extractor, image_classifier) results = extractor_classifier(images) - labels = list(map(itemgetter("label"), results)) + labels = list(map(itemgetter("prediction"), results)) assert labels == expected_predictions