added extractor classifier
This commit is contained in:
parent
a5147c9a58
commit
48737d9439
@ -22,3 +22,6 @@ class ImageClassifier:
|
||||
def predict(self, images: Iterable[Image], batch_size=16):
|
||||
batches = chunk_iterable(images, chunk_size=batch_size)
|
||||
return chain(*map(self.pipe, batches))
|
||||
|
||||
def __call__(self, images: Iterable[Image], batch_size=16):
|
||||
return self.predict(images, batch_size=batch_size)
|
||||
|
||||
0
image_prediction/extractor_classifier/__init__.py
Normal file
0
image_prediction/extractor_classifier/__init__.py
Normal file
@ -0,0 +1,29 @@
|
||||
from itertools import chain
|
||||
from typing import Iterable
|
||||
|
||||
from image_prediction.classifier.image_classifier import ImageClassifier
|
||||
from image_prediction.image_extractor.extractor import ImageExtractor, ImageMetadataPair
|
||||
from image_prediction.utils import chunk_iterable
|
||||
|
||||
|
||||
class ExtractorClassifier:
|
||||
"""Extracts images from an object and classifies them. When called, returns an iterable of dictionaries, where
|
||||
each dictionary has a filed 'label' for the classification and possibly additional fields for metadata."""
|
||||
def __init__(self, image_extractor: ImageExtractor, image_classifier: ImageClassifier):
|
||||
self.classifier = image_classifier
|
||||
self.extractor = image_extractor
|
||||
|
||||
def __process_batch(self, batch):
|
||||
try:
|
||||
images, metadata = zip(*batch)
|
||||
except ValueError:
|
||||
return []
|
||||
|
||||
predictions = self.classifier(images)
|
||||
responses = ({"label": lbl, **mdt} for lbl, mdt in zip(predictions, metadata))
|
||||
return responses
|
||||
|
||||
def __call__(self, obj) -> Iterable[ImageMetadataPair]:
|
||||
image_metadata_pairs = self.extractor(obj)
|
||||
batches = chunk_iterable(image_metadata_pairs, chunk_size=16)
|
||||
return chain(*map(self.__process_batch, batches))
|
||||
16
test/unit_tests/extractor_classifier_test.py
Normal file
16
test/unit_tests/extractor_classifier_test.py
Normal file
@ -0,0 +1,16 @@
|
||||
from operator import itemgetter
|
||||
|
||||
import pytest
|
||||
|
||||
from image_prediction.extractor_classifier.extractor_classifier import ExtractorClassifier
|
||||
|
||||
|
||||
@pytest.mark.parametrize("extractor_type", ["mock"])
|
||||
@pytest.mark.parametrize("estimator_type", ["mock", "keras"])
|
||||
@pytest.mark.parametrize("batch_size", [0, 1, 2, 16, 32, 64])
|
||||
def test_extractor_classifier(image_extractor, image_classifier, images, expected_predictions):
|
||||
extractor_classifier = ExtractorClassifier(image_extractor, image_classifier)
|
||||
results = list(extractor_classifier(images))
|
||||
print(results)
|
||||
labels = list(map(itemgetter("label"), results))
|
||||
assert labels == expected_predictions
|
||||
@ -3,6 +3,6 @@ import pytest
|
||||
|
||||
@pytest.mark.parametrize("extractor_type", ["mock"])
|
||||
@pytest.mark.parametrize("batch_size", [1, 2, 4])
|
||||
def test_image_extraction(image_extractor, images):
|
||||
def test_image_extractor_mock(image_extractor, images):
|
||||
images_extracted, metadata = map(list, zip(*image_extractor(images)))
|
||||
assert images_extracted == images
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user