Matthias Bisping 04cf0245ed formatting
2022-04-11 13:38:09 +02:00

32 lines
1.4 KiB
Python

from itertools import chain
from typing import Iterable
from image_prediction.classifier.image_classifier import ImageClassifier
from image_prediction.image_extractor.extractor import ImageExtractor, ImageMetadataPair
from image_prediction.utils.generic import chunk_iterable
class ExtractorClassifier:
    """Orchestrate image extraction, batched classification and pairing with image metadata.

    Images are extracted from an object by the image extractor, classified in batches by the image classifier, and
    each prediction is merged with the metadata of the image it belongs to. The result is an iterable of dictionaries,
    each with a field 'classification' holding the classifier's prediction plus any additional metadata fields --
    the metadata may be void (empty or None).
    """

    def __init__(self, image_extractor: ImageExtractor, image_classifier: ImageClassifier, batch_size: int = 16):
        """
        Args:
            image_extractor: Callable returning an iterable of (image, metadata) pairs for an object.
            image_classifier: Callable mapping a tuple of images to an iterable of predictions, one per image.
            batch_size: Number of images passed to the classifier per call. Defaults to 16, the value that was
                previously hard-coded.

        Raises:
            ValueError: If ``batch_size`` is smaller than 1.
        """
        if batch_size < 1:
            raise ValueError(f"batch_size must be at least 1, got {batch_size}")
        self.classifier = image_classifier
        self.extractor = image_extractor
        self.batch_size = batch_size

    def __process_batch(self, batch):
        """Classify one batch of (image, metadata) pairs and merge each prediction into its metadata.

        Returns:
            A lazy iterable of dicts of the form ``{"classification": prediction, **metadata}``.
        """
        pairs = list(batch)
        if not pairs:
            # zip(*[]) produces nothing, so unpacking it into two names would raise ValueError.
            return iter(())
        images, metadata = zip(*pairs)
        predictions = self.classifier(images)
        # ``mdt or {}`` accepts void (None) metadata, which ``**None`` would reject with a TypeError.
        # NOTE(review): metadata is merged after "classification", so a metadata key named "classification"
        # would overwrite the prediction -- confirm whether extractors can emit such a key.
        return ({"classification": prd, **(mdt or {})} for prd, mdt in zip(predictions, metadata))

    def __call__(self, obj, **kwargs) -> Iterable[dict]:
        """Extract all images from ``obj``, classify them, and attach each classification to its metadata.

        Args:
            obj: Object handed to the image extractor.
            **kwargs: Forwarded unchanged to the image extractor.

        Returns:
            A lazy iterable of dicts, one per extracted image, each containing a 'classification' key plus the
            image's metadata fields. Batches are classified only as the iterable is consumed.
        """
        image_metadata_pairs = self.extractor(obj, **kwargs)
        batches = chunk_iterable(image_metadata_pairs, chunk_size=self.batch_size)
        return chain.from_iterable(map(self.__process_batch, batches))