30 lines
1.2 KiB
Python
30 lines
1.2 KiB
Python
from itertools import chain
|
|
from typing import Iterable
|
|
|
|
from image_prediction.classifier.image_classifier import ImageClassifier
|
|
from image_prediction.image_extractor.extractor import ImageExtractor, ImageMetadataPair
|
|
from image_prediction.utils import chunk_iterable
|
|
|
|
|
|
class ExtractorClassifier:
    """Extract images from an object and classify them in batches.

    Calling an instance returns an iterable of dictionaries. Each dictionary
    has a field 'label' holding the classification and possibly additional
    fields taken from the metadata that accompanied the image.
    """

    def __init__(
        self,
        image_extractor: ImageExtractor,
        image_classifier: ImageClassifier,
        *,
        batch_size: int = 16,
    ):
        """Store the extractor, the classifier and the batch size.

        Args:
            image_extractor: Called with the object to process; its result is
                consumed as an iterable of (image, metadata) pairs.
            image_classifier: Called with a tuple of images; its result is
                zipped against the metadata, one prediction per image.
            batch_size: Number of pairs handed to the classifier per call.
                Defaults to 16, the value that used to be hard-coded.
        """
        self.classifier = image_classifier
        self.extractor = image_extractor
        self.batch_size = batch_size

    def __process_batch(self, batch):
        """Classify one batch of (image, metadata) pairs.

        Args:
            batch: Iterable of (image, metadata) pairs, where each metadata
                item is a mapping that can be unpacked with ``**``.

        Returns:
            An empty list if the batch cannot be split into an image sequence
            and a metadata sequence (notably when it is empty); otherwise a
            generator of dictionaries, one per pair.
        """
        try:
            # Transpose [(img, meta), ...] into (img, ...) and (meta, ...).
            # An empty batch transposes to nothing, so the two-target
            # unpacking raises ValueError, which is the "nothing to do" case.
            images, metadata = zip(*batch)
        except ValueError:
            return []

        # The classifier runs once per batch, eagerly; only the assembly of
        # the response dictionaries below is deferred.
        predictions = self.classifier(images)

        # Metadata is unpacked after 'label', so a metadata key named 'label'
        # would take precedence over the prediction.
        return ({"label": label, **meta} for label, meta in zip(predictions, metadata))

    def __call__(self, obj) -> Iterable[dict]:
        """Extract all images from ``obj`` and classify them batch by batch.

        Batches are chained lazily: a batch is only classified once the
        consumer reaches it, so errors raised while extracting or classifying
        surface during iteration of the result rather than at call time.
        NOTE(review): this only streams end to end if ``chunk_iterable`` is
        itself lazy; confirm against its implementation.

        Args:
            obj: Whatever the configured extractor accepts.

        Returns:
            An iterable of dictionaries, each with a 'label' entry plus the
            entries of the corresponding metadata.
        """
        image_metadata_pairs = self.extractor(obj)
        batches = chunk_iterable(image_metadata_pairs, chunk_size=self.batch_size)
        # chain.from_iterable pulls batches on demand; chain(*...) would
        # unpack, and therefore classify, every batch before returning.
        return chain.from_iterable(map(self.__process_batch, batches))
|