added extractor classifier

This commit is contained in:
Matthias Bisping 2022-03-28 00:01:19 +02:00
parent a5147c9a58
commit 48737d9439
5 changed files with 49 additions and 1 deletions

View File

@ -22,3 +22,6 @@ class ImageClassifier:
def predict(self, images: Iterable[Image], batch_size=16):
batches = chunk_iterable(images, chunk_size=batch_size)
return chain(*map(self.pipe, batches))
def __call__(self, images: Iterable[Image], batch_size=16):
return self.predict(images, batch_size=batch_size)

View File

@ -0,0 +1,29 @@
from itertools import chain
from typing import Iterable
from image_prediction.classifier.image_classifier import ImageClassifier
from image_prediction.image_extractor.extractor import ImageExtractor, ImageMetadataPair
from image_prediction.utils import chunk_iterable
class ExtractorClassifier:
"""Extracts images from an object and classifies them. When called, returns an iterable of dictionaries, where
each dictionary has a filed 'label' for the classification and possibly additional fields for metadata."""
def __init__(self, image_extractor: ImageExtractor, image_classifier: ImageClassifier):
self.classifier = image_classifier
self.extractor = image_extractor
def __process_batch(self, batch):
try:
images, metadata = zip(*batch)
except ValueError:
return []
predictions = self.classifier(images)
responses = ({"label": lbl, **mdt} for lbl, mdt in zip(predictions, metadata))
return responses
def __call__(self, obj) -> Iterable[ImageMetadataPair]:
image_metadata_pairs = self.extractor(obj)
batches = chunk_iterable(image_metadata_pairs, chunk_size=16)
return chain(*map(self.__process_batch, batches))

View File

@ -0,0 +1,16 @@
from operator import itemgetter
import pytest
from image_prediction.extractor_classifier.extractor_classifier import ExtractorClassifier
@pytest.mark.parametrize("extractor_type", ["mock"])
@pytest.mark.parametrize("estimator_type", ["mock", "keras"])
@pytest.mark.parametrize("batch_size", [0, 1, 2, 16, 32, 64])
def test_extractor_classifier(image_extractor, image_classifier, images, expected_predictions):
extractor_classifier = ExtractorClassifier(image_extractor, image_classifier)
results = list(extractor_classifier(images))
print(results)
labels = list(map(itemgetter("label"), results))
assert labels == expected_predictions

View File

@ -3,6 +3,6 @@ import pytest
@pytest.mark.parametrize("extractor_type", ["mock"])
@pytest.mark.parametrize("batch_size", [1, 2, 4])
def test_image_extraction(image_extractor, images):
def test_image_extractor_mock(image_extractor, images):
images_extracted, metadata = map(list, zip(*image_extractor(images)))
assert images_extracted == images