from itertools import chain
from typing import Iterable

from image_prediction.classifier.image_classifier import ImageClassifier
from image_prediction.image_extractor.extractor import ImageExtractor, ImageMetadataPair
from image_prediction.utils.generic import chunk_iterable


class ExtractorClassifier:
    """Orchestrate the pairing of classifications and image metadata.

    Images are extracted from an object together with their metadata, classified
    in batches, and each classification is then merged with the metadata of the
    image it belongs to. The result is an iterable of dictionaries, where each
    dictionary has a field 'classification' holding the classifier output and
    possibly additional fields for metadata -- metadata could be void.
    """

    # Class-level default so that instances created without running __init__
    # (e.g. subclasses overriding it) still batch the way they always did.
    batch_size: int = 16

    def __init__(
        self,
        image_extractor: ImageExtractor,
        image_classifier: ImageClassifier,
        batch_size: int = 16,
    ):
        """Store the collaborators used for extraction and classification.

        Args:
            image_extractor: Callable that is invoked as ``extractor(obj, **kwargs)``
                and yields (image, metadata) pairs.
            image_classifier: Callable that is invoked with a tuple of images and
                returns one prediction per image.
            batch_size: Number of images handed to the classifier per call.

        Raises:
            ValueError: If ``batch_size`` is smaller than 1.
        """
        if batch_size < 1:
            raise ValueError(f"batch_size must be a positive integer, got {batch_size!r}")

        self.classifier = image_classifier
        self.extractor = image_extractor
        self.batch_size = batch_size

    def __process_batch(self, batch) -> Iterable[dict]:
        """Classify one batch of (image, metadata) pairs.

        Args:
            batch: Iterable of (image, metadata) pairs; metadata must be a mapping
                because it is unpacked into the response dictionary.

        Returns:
            An iterable of dictionaries, one per pair, each containing the key
            'classification' plus all metadata fields of the respective image.
        """
        pairs = tuple(batch)
        if not pairs:
            # zip(*()) yields nothing, so unpacking into two names would raise a
            # ValueError; an empty batch simply produces no responses.
            return ()

        images, metadata = zip(*pairs)
        predictions = self.classifier(images)
        # Metadata is unpacked last, so a metadata field named 'classification'
        # would take precedence over the prediction (unchanged behavior).
        responses = ({"classification": prd, **mdt} for prd, mdt in zip(predictions, metadata))

        return responses

    def __call__(self, obj, **kwargs) -> Iterable[dict]:
        """Extract images from ``obj``, classify them and attach their metadata.

        Args:
            obj: Object passed to the image extractor.
            **kwargs: Forwarded unchanged to the image extractor.

        Returns:
            A lazy iterable of dictionaries as produced per batch, chained across
            all batches in extraction order.
        """
        image_metadata_pairs = self.extractor(obj, **kwargs)
        batches = chunk_iterable(image_metadata_pairs, chunk_size=self.batch_size)
        # map + chain.from_iterable keep the pipeline lazy: a batch is only
        # classified once the consumer reaches it.
        predictions = chain.from_iterable(map(self.__process_batch, batches))

        return predictions