from itertools import chain
from typing import Iterable

from image_prediction.classifier.image_classifier import ImageClassifier
from image_prediction.image_extractor.extractor import ImageExtractor, ImageMetadataPair
from image_prediction.utils import chunk_iterable


class ExtractorClassifier:
    """Extract images from an object and classify them.

    When called, returns an iterable of dictionaries. Each dictionary has a
    field ``'label'`` holding the classification and possibly additional
    fields taken from the image's metadata.
    """

    def __init__(
        self,
        image_extractor: ImageExtractor,
        image_classifier: ImageClassifier,
        batch_size: int = 16,
    ):
        """Create an extractor/classifier pipeline.

        Args:
            image_extractor: Called with the input object; expected to return
                an iterable of (image, metadata) pairs.
            image_classifier: Called with a tuple of images; expected to return
                one prediction per image, in the same order.
            batch_size: Maximum number of images handed to the classifier in a
                single call. Defaults to 16.

        Raises:
            ValueError: If ``batch_size`` is smaller than 1.
        """
        if batch_size < 1:
            raise ValueError(f"batch_size must be a positive integer, got {batch_size}")
        self.classifier = image_classifier
        self.extractor = image_extractor
        self.batch_size = batch_size

    def __process_batch(self, batch: Iterable[ImageMetadataPair]) -> Iterable[dict]:
        """Classify one batch of (image, metadata) pairs.

        Returns an iterable of ``{"label": prediction, **metadata}`` dicts, or an
        empty list when the batch contains no pairs.
        """
        pairs = list(batch)
        if not pairs:
            # Nothing to classify; unpacking zip() of nothing would fail, and the
            # classifier should not be called with zero images.
            return []
        # Malformed pairs (not exactly two elements) now raise here instead of
        # silently discarding the whole batch.
        images, metadata = zip(*pairs)
        predictions = self.classifier(images)
        # NOTE: metadata is unpacked after "label", so a metadata key named
        # "label" would overwrite the prediction.
        return ({"label": label, **meta} for label, meta in zip(predictions, metadata))

    def __call__(self, obj) -> Iterable[dict]:
        """Extract images from ``obj`` and classify them in batches.

        Args:
            obj: Whatever the configured image extractor accepts.

        Returns:
            An iterable of dicts, one per extracted image, each containing a
            ``'label'`` key plus that image's metadata fields.
        """
        image_metadata_pairs = self.extractor(obj)
        batches = chunk_iterable(image_metadata_pairs, chunk_size=self.batch_size)
        # Every batch is classified before this method returns (matching the
        # previous ``chain(*map(...))`` behaviour); only the construction of the
        # result dicts is deferred until iteration.
        processed = [self.__process_batch(batch) for batch in batches]
        return chain.from_iterable(processed)