From 364111db89a5473be9224d92845ac4d811352b9e Mon Sep 17 00:00:00 2001 From: Matthias Bisping Date: Sat, 26 Mar 2022 19:38:34 +0100 Subject: [PATCH] preprocessor refactoring --- .../estimator/preprocessor/preprocessor.py | 12 ++++++------ .../preprocessor/preprocessors/__init__.py | 0 .../preprocessor/preprocessors/tensor_conversion.py | 13 +++++++++++++ test/unit_tests/preprocessor_test.py | 6 ++++++ 4 files changed, 25 insertions(+), 6 deletions(-) create mode 100644 image_prediction/estimator/preprocessor/preprocessors/__init__.py create mode 100644 image_prediction/estimator/preprocessor/preprocessors/tensor_conversion.py diff --git a/image_prediction/estimator/preprocessor/preprocessor.py b/image_prediction/estimator/preprocessor/preprocessor.py index aab8f12..5f8b026 100644 --- a/image_prediction/estimator/preprocessor/preprocessor.py +++ b/image_prediction/estimator/preprocessor/preprocessor.py @@ -1,12 +1,12 @@ -from image_prediction.estimator.adapter.adapter import EstimatorAdapter +import abc -class Preprocessor: - def __init__(self): +class Preprocessor(abc.ABC): + + @staticmethod + @abc.abstractmethod + def preprocess(batch): pass - # def preprocess(self, batch): - # return (map(image_to_normalized_tensor, batch)) - def __call__(self, batch): return self.preprocess(batch) diff --git a/image_prediction/estimator/preprocessor/preprocessors/__init__.py b/image_prediction/estimator/preprocessor/preprocessors/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/image_prediction/estimator/preprocessor/preprocessors/tensor_conversion.py b/image_prediction/estimator/preprocessor/preprocessors/tensor_conversion.py new file mode 100644 index 0000000..f8e5943 --- /dev/null +++ b/image_prediction/estimator/preprocessor/preprocessors/tensor_conversion.py @@ -0,0 +1,13 @@ +from image_prediction.estimator.preprocessor.preprocessor import Preprocessor +from image_prediction.estimator.preprocessor.utils import images_to_batch_tensor + + +class BasicPreprocessor(Preprocessor): + """Converts images to tensors""" + + @staticmethod + def preprocess(images): + return images_to_batch_tensor(images) + + def __call__(self, images): + return self.preprocess(images) diff --git a/test/unit_tests/preprocessor_test.py b/test/unit_tests/preprocessor_test.py index 327e3fb..562e7e6 100644 --- a/test/unit_tests/preprocessor_test.py +++ b/test/unit_tests/preprocessor_test.py @@ -1,6 +1,7 @@ import numpy as np from PIL import Image +from image_prediction.estimator.preprocessor.preprocessors.tensor_conversion import BasicPreprocessor from image_prediction.estimator.preprocessor.utils import image_to_normalized_tensor, images_to_batch_tensor @@ -18,3 +19,8 @@ def test_images_to_batch_tensor(images): tensor = images_to_batch_tensor(images) assert isinstance(tensor, np.ndarray) assert tensor.ndim == 4 + + +def test_basic_preprocessor(images): + tensor = BasicPreprocessor().preprocess(images) + assert tensor.ndim == 4