diff --git a/image_prediction/exceptions.py b/image_prediction/exceptions.py index 290b600..3098a10 100644 --- a/image_prediction/exceptions.py +++ b/image_prediction/exceptions.py @@ -1,2 +1,6 @@ class UnknownEstimatorAdapter(ValueError): - pass \ No newline at end of file + pass + + +class UnknownImageExtractor(ValueError): + pass diff --git a/image_prediction/image_extractor/__init__.py b/image_prediction/image_extractor/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/image_prediction/image_extractor/extractor.py b/image_prediction/image_extractor/extractor.py new file mode 100644 index 0000000..1405aee --- /dev/null +++ b/image_prediction/image_extractor/extractor.py @@ -0,0 +1,16 @@ +import abc +from collections import namedtuple +from typing import Iterable + +ImageMetadataPair = namedtuple("ImageMetadataPair", ["image", "metadata"]) + + +class ImageExtractor(abc.ABC): + + @abc.abstractmethod + def extract(self, obj) -> Iterable[ImageMetadataPair]: + """Extracts images from an object""" + pass + + def __call__(self, obj): + return self.extract(obj) diff --git a/image_prediction/image_extractor/extractors/__init__.py b/image_prediction/image_extractor/extractors/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/image_prediction/image_extractor/extractors/mock.py b/image_prediction/image_extractor/extractors/mock.py new file mode 100644 index 0000000..cfbc4d8 --- /dev/null +++ b/image_prediction/image_extractor/extractors/mock.py @@ -0,0 +1,8 @@ +from image_prediction.image_extractor.extractor import ImageExtractor, ImageMetadataPair + + +class ImageExtractorMock(ImageExtractor): + + def extract(self, image_container): + for i, image in enumerate(image_container): + yield ImageMetadataPair(image, {"image_id": i}) diff --git a/test/unit_tests/conftest.py b/test/unit_tests/conftest.py index dabdfe6..6bd7680 100644 --- a/test/unit_tests/conftest.py +++ b/test/unit_tests/conftest.py @@ -8,7 +8,16 @@ from image_prediction.classifier.classifier import Classifier from image_prediction.classifier.image_classifier import ImageClassifier from image_prediction.estimator.adapter.adapters.keras import KerasEstimatorAdapter from image_prediction.estimator.adapter.adapters.mock import EstimatorMock, EstimatorAdapterMock -from image_prediction.exceptions import UnknownEstimatorAdapter +from image_prediction.exceptions import UnknownEstimatorAdapter, UnknownImageExtractor +from image_prediction.image_extractor.extractors.mock import ImageExtractorMock + + +@pytest.fixture +def image_extractor(extractor_type): + if extractor_type == "mock": + return ImageExtractorMock() + else: + raise UnknownImageExtractor(f"No image extractor for type {extractor_type} was specified.") @pytest.fixture diff --git a/test/unit_tests/image_extractor_test.py b/test/unit_tests/image_extractor_test.py new file mode 100644 index 0000000..b4266b3 --- /dev/null +++ b/test/unit_tests/image_extractor_test.py @@ -0,0 +1,8 @@ +import pytest + + +@pytest.mark.parametrize("extractor_type", ["mock"]) +@pytest.mark.parametrize("batch_size", [1, 2, 4]) +def test_image_extraction(image_extractor, images): + images_extracted, metadata = map(list, zip(*image_extractor(images))) + assert images_extracted == images