From 48737d943935c31cb43ddd60e003f7144337af40 Mon Sep 17 00:00:00 2001 From: Matthias Bisping Date: Mon, 28 Mar 2022 00:01:19 +0200 Subject: [PATCH] added extractor classifier --- .../classifier/image_classifier.py | 3 ++ .../extractor_classifier/__init__.py | 0 .../extractor_classifier.py | 29 +++++++++++++++++++ test/unit_tests/extractor_classifier_test.py | 16 ++++++++++ test/unit_tests/image_extractor_test.py | 2 +- 5 files changed, 49 insertions(+), 1 deletion(-) create mode 100644 image_prediction/extractor_classifier/__init__.py create mode 100644 image_prediction/extractor_classifier/extractor_classifier.py create mode 100644 test/unit_tests/extractor_classifier_test.py diff --git a/image_prediction/classifier/image_classifier.py b/image_prediction/classifier/image_classifier.py index 4467881..212205f 100644 --- a/image_prediction/classifier/image_classifier.py +++ b/image_prediction/classifier/image_classifier.py @@ -22,3 +22,6 @@ class ImageClassifier: def predict(self, images: Iterable[Image], batch_size=16): batches = chunk_iterable(images, chunk_size=batch_size) return chain(*map(self.pipe, batches)) + + def __call__(self, images: Iterable[Image], batch_size=16): + return self.predict(images, batch_size=batch_size) diff --git a/image_prediction/extractor_classifier/__init__.py b/image_prediction/extractor_classifier/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/image_prediction/extractor_classifier/extractor_classifier.py b/image_prediction/extractor_classifier/extractor_classifier.py new file mode 100644 index 0000000..17bb5a9 --- /dev/null +++ b/image_prediction/extractor_classifier/extractor_classifier.py @@ -0,0 +1,29 @@ +from itertools import chain +from typing import Iterable + +from image_prediction.classifier.image_classifier import ImageClassifier +from image_prediction.image_extractor.extractor import ImageExtractor, ImageMetadataPair +from image_prediction.utils import chunk_iterable + + +class ExtractorClassifier: + """Extracts images from an object and classifies them. When called, returns an iterable of dictionaries, where + each dictionary has a filed 'label' for the classification and possibly additional fields for metadata.""" + def __init__(self, image_extractor: ImageExtractor, image_classifier: ImageClassifier): + self.classifier = image_classifier + self.extractor = image_extractor + + def __process_batch(self, batch): + try: + images, metadata = zip(*batch) + except ValueError: + return [] + + predictions = self.classifier(images) + responses = ({"label": lbl, **mdt} for lbl, mdt in zip(predictions, metadata)) + return responses + + def __call__(self, obj) -> Iterable[ImageMetadataPair]: + image_metadata_pairs = self.extractor(obj) + batches = chunk_iterable(image_metadata_pairs, chunk_size=16) + return chain(*map(self.__process_batch, batches)) diff --git a/test/unit_tests/extractor_classifier_test.py b/test/unit_tests/extractor_classifier_test.py new file mode 100644 index 0000000..7134010 --- /dev/null +++ b/test/unit_tests/extractor_classifier_test.py @@ -0,0 +1,16 @@ +from operator import itemgetter + +import pytest + +from image_prediction.extractor_classifier.extractor_classifier import ExtractorClassifier + + +@pytest.mark.parametrize("extractor_type", ["mock"]) +@pytest.mark.parametrize("estimator_type", ["mock", "keras"]) +@pytest.mark.parametrize("batch_size", [0, 1, 2, 16, 32, 64]) +def test_extractor_classifier(image_extractor, image_classifier, images, expected_predictions): + extractor_classifier = ExtractorClassifier(image_extractor, image_classifier) + results = list(extractor_classifier(images)) + print(results) + labels = list(map(itemgetter("label"), results)) + assert labels == expected_predictions diff --git a/test/unit_tests/image_extractor_test.py b/test/unit_tests/image_extractor_test.py index b4266b3..f7fecd2 100644 --- a/test/unit_tests/image_extractor_test.py +++ b/test/unit_tests/image_extractor_test.py @@ -3,6 +3,6 @@ import pytest @pytest.mark.parametrize("extractor_type", ["mock"]) @pytest.mark.parametrize("batch_size", [1, 2, 4]) -def test_image_extraction(image_extractor, images): +def test_image_extractor_mock(image_extractor, images): images_extracted, metadata = map(list, zip(*image_extractor(images))) assert images_extracted == images