From a1c7dd4a8d7c999d06765a3e452fd75d851f08af Mon Sep 17 00:00:00 2001 From: Matthias Bisping Date: Tue, 29 Mar 2022 11:40:58 +0200 Subject: [PATCH] added identity preprocessor; changed default preprocessor to idenitity --- image_prediction/classifier/image_classifier.py | 4 ++-- .../estimator/preprocessor/preprocessors/identity.py | 11 +++++++++++ test/unit_tests/conftest.py | 3 ++- test/unit_tests/preprocessor_test.py | 7 +++++++ 4 files changed, 22 insertions(+), 3 deletions(-) create mode 100644 image_prediction/estimator/preprocessor/preprocessors/identity.py diff --git a/image_prediction/classifier/image_classifier.py b/image_prediction/classifier/image_classifier.py index 30b62ed..7bd21fc 100644 --- a/image_prediction/classifier/image_classifier.py +++ b/image_prediction/classifier/image_classifier.py @@ -6,7 +6,7 @@ from funcy import rcompose from image_prediction.classifier.classifier import Classifier from image_prediction.estimator.preprocessor.preprocessor import Preprocessor -from image_prediction.estimator.preprocessor.preprocessors.basic import BasicPreprocessor +from image_prediction.estimator.preprocessor.preprocessors.identity import IdentityPreprocessor from image_prediction.utils import chunk_iterable @@ -17,7 +17,7 @@ class ImageClassifier: def __init__(self, classifier: Classifier, preprocessor: Preprocessor = None): self.estimator = classifier - self.preprocessor = preprocessor if preprocessor else BasicPreprocessor() + self.preprocessor = preprocessor if preprocessor else IdentityPreprocessor() self.pipe = rcompose(self.preprocessor, self.estimator) def predict(self, images: Iterable[Image], batch_size=16): diff --git a/image_prediction/estimator/preprocessor/preprocessors/identity.py b/image_prediction/estimator/preprocessor/preprocessors/identity.py new file mode 100644 index 0000000..dc5b335 --- /dev/null +++ b/image_prediction/estimator/preprocessor/preprocessors/identity.py @@ -0,0 +1,11 @@ +from image_prediction.estimator.preprocessor.preprocessor import Preprocessor + + +class IdentityPreprocessor(Preprocessor): + + @staticmethod + def preprocess(images): + return images + + def __call__(self, images): + return self.preprocess(images) diff --git a/test/unit_tests/conftest.py b/test/unit_tests/conftest.py index 6506c1d..a281cc7 100644 --- a/test/unit_tests/conftest.py +++ b/test/unit_tests/conftest.py @@ -12,6 +12,7 @@ from image_prediction.classifier.classifier import Classifier from image_prediction.classifier.image_classifier import ImageClassifier from image_prediction.estimator.adapter.adapters.keras import KerasEstimatorAdapter from image_prediction.estimator.adapter.adapters.mock import EstimatorMock, EstimatorAdapterMock +from image_prediction.estimator.preprocessor.preprocessors.basic import BasicPreprocessor from image_prediction.exceptions import UnknownEstimatorAdapter, UnknownImageExtractor, UnknownModelLoader from image_prediction.image_extractor.extractor import ImageMetadataPair from image_prediction.image_extractor.extractors.mock import ImageExtractorMock @@ -33,7 +34,7 @@ def image_extractor(extractor_type): @pytest.fixture def image_classifier(classifier, monkeypatch, expected_predictions): - return ImageClassifier(classifier) + return ImageClassifier(classifier, preprocessor=BasicPreprocessor()) @pytest.fixture diff --git a/test/unit_tests/preprocessor_test.py b/test/unit_tests/preprocessor_test.py index 72e8618..91219d8 100644 --- a/test/unit_tests/preprocessor_test.py +++ b/test/unit_tests/preprocessor_test.py @@ -3,6 +3,7 @@ import pytest from PIL import Image from image_prediction.estimator.preprocessor.preprocessors.basic import BasicPreprocessor +from image_prediction.estimator.preprocessor.preprocessors.identity import IdentityPreprocessor from image_prediction.estimator.preprocessor.utils import image_to_normalized_tensor, images_to_batch_tensor @@ -35,3 +36,9 @@ def test_images_to_batch_tensor(images): def test_basic_preprocessor(images): tensor = BasicPreprocessor()(images) assert images_conversion_is_correct(images, tensor) + + +@pytest.mark.parametrize("batch_size", [0, 1, 2, 4, 6], scope="session") +def test_identity_preprocessor(images): + images_preprocessed = IdentityPreprocessor()(images) + assert images_preprocessed == images