34 lines
1.4 KiB
Python
34 lines
1.4 KiB
Python
from itertools import chain
|
|
from typing import Iterable
|
|
|
|
from funcy import chunks, rpartial
|
|
|
|
from image_prediction.classifier.image_classifier import ImageClassifier
|
|
from image_prediction.image_extractor.extractor import ImageExtractor
|
|
|
|
|
|
class ExtractorClassifier:
    """Orchestrate the pairing of image classifications with image metadata.

    Calling an instance runs the extractor on an object to obtain
    (image, metadata) pairs, classifies the images batch by batch and merges
    each prediction with the metadata of the image it belongs to. The result
    is an iterable of dictionaries, where each dictionary holds the prediction
    under the key 'classification' and possibly additional fields taken from
    the metadata -- metadata could be void.
    """

    def __init__(self, image_extractor: ImageExtractor, image_classifier: ImageClassifier):
        """Store the extractor and the classifier used by ``__call__``.

        Args:
            image_extractor: Called as ``image_extractor(obj, **kwargs)``; its
                result is consumed as an iterable of (image, metadata) pairs.
            image_classifier: Called as ``image_classifier(images, batch_size)``
                with a tuple of images; its result is iterated to obtain one
                prediction per image.
        """
        self.classifier = image_classifier
        self.extractor = image_extractor

    def __process_batch(self, batch, batch_size):
        """Classify one batch and merge every prediction with its metadata.

        Args:
            batch: Non-empty collection of (image, metadata) pairs, where each
                metadata item is a mapping.
            batch_size: Forwarded unchanged to the classifier.

        Returns:
            A generator of dictionaries, one per pair in ``batch``.
        """
        # Split the pairs into two parallel tuples: all images, all metadata.
        images, metadata = zip(*batch)

        predictions = self.classifier(images, batch_size)

        # Predictions are matched to metadata by position.
        # NOTE: metadata is unpacked after the prediction, so a metadata key
        # named "classification" would take precedence over the prediction.
        return ({"classification": prd, **mdt} for prd, mdt in zip(predictions, metadata))

    def __call__(self, obj, batch_size=16, **kwargs) -> Iterable[dict]:
        """Extract the images from ``obj``, classify them and yield the results lazily.

        Args:
            obj: The object handed to the extractor.
            batch_size: Number of images per classifier call; must be at least 1.
            **kwargs: Forwarded unchanged to the extractor.

        Yields:
            One dictionary per extracted image, containing the key
            'classification' plus the metadata fields of that image.

        Raises:
            ValueError: If ``batch_size`` is smaller than 1. As this method is
                a generator, the error surfaces when iteration starts.
        """
        # A non-positive batch size cannot form batches; fail with a clear
        # message before any extraction work is done.
        if batch_size < 1:
            raise ValueError(f"batch_size must be a positive integer, got {batch_size!r}")

        image_metadata_pairs = self.extractor(obj, **kwargs)
        batches = chunks(batch_size, image_metadata_pairs)
        # rpartial fixes batch_size as the trailing argument, so map() only
        # has to supply the batch itself.
        predictions = chain.from_iterable(map(rpartial(self.__process_batch, batch_size), batches))
        yield from predictions
|