From a5147c9a584b8c03f48cab2d5edcd80f0d303a1e Mon Sep 17 00:00:00 2001 From: Matthias Bisping Date: Sun, 27 Mar 2022 23:05:27 +0200 Subject: [PATCH] added image extractor interface and mock --- image_prediction/exceptions.py | 6 +++++- image_prediction/image_extractor/__init__.py | 0 image_prediction/image_extractor/extractor.py | 16 ++++++++++++++++ .../image_extractor/extractors/__init__.py | 0 .../image_extractor/extractors/mock.py | 8 ++++++++ test/unit_tests/conftest.py | 11 ++++++++++- test/unit_tests/image_extractor_test.py | 8 ++++++++ 7 files changed, 47 insertions(+), 2 deletions(-) create mode 100644 image_prediction/image_extractor/__init__.py create mode 100644 image_prediction/image_extractor/extractor.py create mode 100644 image_prediction/image_extractor/extractors/__init__.py create mode 100644 image_prediction/image_extractor/extractors/mock.py create mode 100644 test/unit_tests/image_extractor_test.py diff --git a/image_prediction/exceptions.py b/image_prediction/exceptions.py index 290b600..3098a10 100644 --- a/image_prediction/exceptions.py +++ b/image_prediction/exceptions.py @@ -1,2 +1,6 @@ class UnknownEstimatorAdapter(ValueError): - pass \ No newline at end of file + pass + + +class UnknownImageExtractor(ValueError): + pass diff --git a/image_prediction/image_extractor/__init__.py b/image_prediction/image_extractor/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/image_prediction/image_extractor/extractor.py b/image_prediction/image_extractor/extractor.py new file mode 100644 index 0000000..1405aee --- /dev/null +++ b/image_prediction/image_extractor/extractor.py @@ -0,0 +1,16 @@ +import abc +from collections import namedtuple +from typing import Iterable + +ImageMetadataPair = namedtuple("ImageMetadataPair", ["image", "metadata"]) + + +class ImageExtractor(abc.ABC): + + @abc.abstractmethod + def extract(self, obj) -> Iterable[ImageMetadataPair]: + """Extracts images from an object""" + pass + + def __call__(self, obj): + return self.extract(obj) diff --git a/image_prediction/image_extractor/extractors/__init__.py b/image_prediction/image_extractor/extractors/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/image_prediction/image_extractor/extractors/mock.py b/image_prediction/image_extractor/extractors/mock.py new file mode 100644 index 0000000..cfbc4d8 --- /dev/null +++ b/image_prediction/image_extractor/extractors/mock.py @@ -0,0 +1,8 @@ +from image_prediction.image_extractor.extractor import ImageExtractor, ImageMetadataPair + + +class ImageExtractorMock(ImageExtractor): + + def extract(self, image_container): + for i, image in enumerate(image_container): + yield ImageMetadataPair(image, {"image_id": i}) diff --git a/test/unit_tests/conftest.py b/test/unit_tests/conftest.py index dabdfe6..6bd7680 100644 --- a/test/unit_tests/conftest.py +++ b/test/unit_tests/conftest.py @@ -8,7 +8,16 @@ from image_prediction.classifier.classifier import Classifier from image_prediction.classifier.image_classifier import ImageClassifier from image_prediction.estimator.adapter.adapters.keras import KerasEstimatorAdapter from image_prediction.estimator.adapter.adapters.mock import EstimatorMock, EstimatorAdapterMock -from image_prediction.exceptions import UnknownEstimatorAdapter +from image_prediction.exceptions import UnknownEstimatorAdapter, UnknownImageExtractor +from image_prediction.image_extractor.extractors.mock import ImageExtractorMock + + +@pytest.fixture +def image_extractor(extractor_type): + if extractor_type == "mock": + return ImageExtractorMock() + else: + raise UnknownImageExtractor(f"No image extractor for type {extractor_type} was specified.") @pytest.fixture diff --git a/test/unit_tests/image_extractor_test.py b/test/unit_tests/image_extractor_test.py new file mode 100644 index 0000000..b4266b3 --- /dev/null +++ b/test/unit_tests/image_extractor_test.py @@ -0,0 +1,8 @@ +import pytest + + +@pytest.mark.parametrize("extractor_type", ["mock"]) +@pytest.mark.parametrize("batch_size", [1, 2, 4]) +def test_image_extraction(image_extractor, images): + images_extracted, metadata = map(list, zip(*image_extractor(images))) + assert images_extracted == images