32 lines
1.4 KiB
Python
32 lines
1.4 KiB
Python
from itertools import chain
|
|
from typing import Iterable
|
|
|
|
from image_prediction.classifier.image_classifier import ImageClassifier
|
|
from image_prediction.image_extractor.extractor import ImageExtractor, ImageMetadataPair
|
|
from image_prediction.utils.generic import chunk_iterable
|
|
|
|
|
|
class ExtractorClassifier:
    """Pair image classifications with the metadata of the images they were computed for.

    Images are extracted from an object by the image extractor, classified in batches by the
    image classifier, and each classification is merged with the metadata of its image. The
    result is an iterable of dictionaries, each with a ``"classification"`` key holding the
    prediction plus any metadata fields. Metadata may be empty (or None), in which case the
    dictionary contains only ``"classification"``.
    """

    def __init__(
        self,
        image_extractor: "ImageExtractor",
        image_classifier: "ImageClassifier",
        batch_size: int = 16,
    ):
        """
        Args:
            image_extractor: Callable ``(obj, **kwargs)`` returning an iterable of
                (image, metadata) pairs.
            image_classifier: Callable taking a tuple of images; it is expected to return
                predictions in the same order as the images.
            batch_size: Number of images handed to the classifier per call. Defaults to 16.

        Raises:
            ValueError: If ``batch_size`` is smaller than 1.
        """
        if batch_size < 1:
            raise ValueError(f"batch_size must be a positive integer, got {batch_size!r}")
        self.classifier = image_classifier
        self.extractor = image_extractor
        self.batch_size = batch_size

    def __process_batch(self, batch):
        """Classify one batch of (image, metadata) pairs and merge predictions with metadata.

        Returns:
            A lazy iterator of dictionaries ``{"classification": prediction, **metadata}``.
            Metadata is unpacked after ``"classification"``, so a metadata key with that
            name overrides the prediction.
        """
        if not batch:
            # zip(*[]) yields nothing, so unpacking it into two names would raise ValueError.
            return iter(())

        images, metadata = zip(*batch)
        predictions = self.classifier(images)
        # `or {}` tolerates pairs whose metadata is None ("void") instead of failing on `**None`.
        return (
            {"classification": prediction, **(meta or {})}
            for prediction, meta in zip(predictions, metadata)
        )

    def __call__(self, obj, **kwargs) -> Iterable[dict]:
        """Extract images from ``obj``, classify them in batches and return the merged results.

        Args:
            obj: Object passed to the image extractor.
            **kwargs: Forwarded unchanged to the image extractor.

        Returns:
            A lazy iterable of dictionaries, each with a ``"classification"`` key plus the
            metadata fields of the corresponding image. Because ``map`` and
            ``chain.from_iterable`` are lazy, the classifier is invoked batch by batch as the
            result is consumed.
        """
        image_metadata_pairs = self.extractor(obj, **kwargs)
        batches = chunk_iterable(image_metadata_pairs, chunk_size=self.batch_size)
        return chain.from_iterable(map(self.__process_batch, batches))
|