added image extractor interface and mock

This commit is contained in:
Matthias Bisping 2022-03-27 23:05:27 +02:00
parent 4c939464b0
commit a5147c9a58
7 changed files with 47 additions and 2 deletions

View File

@ -1,2 +1,6 @@
class UnknownEstimatorAdapter(ValueError):
pass
pass
class UnknownImageExtractor(ValueError):
pass

View File

@ -0,0 +1,16 @@
import abc
from collections import namedtuple
from typing import Iterable
ImageMetadataPair = namedtuple("ImageMetadataPair", ["image", "metadata"])
class ImageExtractor(abc.ABC):
@abc.abstractmethod
def extract(self, obj) -> Iterable[ImageMetadataPair]:
"""Extracts images from an object"""
pass
def __call__(self, obj):
return self.extract(obj)

View File

@ -0,0 +1,8 @@
from image_prediction.image_extractor.extractor import ImageExtractor, ImageMetadataPair
class ImageExtractorMock(ImageExtractor):
def extract(self, image_container):
for i, image in enumerate(image_container):
yield ImageMetadataPair(image, {"image_id": i})

View File

@ -8,7 +8,16 @@ from image_prediction.classifier.classifier import Classifier
from image_prediction.classifier.image_classifier import ImageClassifier
from image_prediction.estimator.adapter.adapters.keras import KerasEstimatorAdapter
from image_prediction.estimator.adapter.adapters.mock import EstimatorMock, EstimatorAdapterMock
from image_prediction.exceptions import UnknownEstimatorAdapter
from image_prediction.exceptions import UnknownEstimatorAdapter, UnknownImageExtractor
from image_prediction.image_extractor.extractors.mock import ImageExtractorMock
@pytest.fixture
def image_extractor(extractor_type):
if extractor_type == "mock":
return ImageExtractorMock()
else:
raise UnknownImageExtractor(f"No image extractor for type {extractor_type} was specified.")
@pytest.fixture

View File

@ -0,0 +1,8 @@
import pytest
@pytest.mark.parametrize("extractor_type", ["mock"])
@pytest.mark.parametrize("batch_size", [1, 2, 4])
def test_image_extraction(image_extractor, images):
images_extracted, metadata = map(list, zip(*image_extractor(images)))
assert images_extracted == images