from itertools import chain
from typing import Iterable

from funcy import chunks, rpartial

from image_prediction.classifier.image_classifier import ImageClassifier
from image_prediction.image_extractor.extractor import ImageExtractor


class ExtractorClassifier:
    """Orchestrate image extraction, classification and metadata pairing.

    Images are extracted from an object and classified in batches. Each
    classification is then merged with the metadata the extractor produced for
    the same image. Calling an instance yields one dictionary per image. Each
    dictionary has a ``"classification"`` key holding the classifier's
    prediction, plus whatever metadata fields accompanied the image (the
    metadata may be empty).
    """

    def __init__(self, image_extractor: ImageExtractor, image_classifier: ImageClassifier):
        self.classifier = image_classifier
        self.extractor = image_extractor

    def __process_batch(self, batch, batch_size):
        """Classify one batch of ``(image, metadata)`` pairs.

        Args:
            batch: Non-empty sequence of ``(image, metadata)`` pairs. Each
                ``metadata`` must be a mapping, since it is unpacked with ``**``.
            batch_size: Forwarded unchanged to the classifier.

        Returns:
            A lazy iterable of dicts ``{"classification": prediction, **metadata}``.
            Predictions are paired positionally with the images in ``batch``.
        """
        images, metadata = zip(*batch)
        predictions = self.classifier(images, batch_size)
        # Metadata is unpacked after "classification", so a metadata key with
        # that same name overwrites the prediction. zip() also stops silently at
        # the shorter of predictions/metadata.
        return ({"classification": prd, **mdt} for prd, mdt in zip(predictions, metadata))

    def __call__(self, obj, batch_size=16, **kwargs) -> Iterable[dict]:
        """Extract images from ``obj``, classify them, and yield result dicts.

        This is a generator, so no work happens (including validation) until
        the first item is requested.

        Args:
            obj: Object handed to the extractor.
            batch_size: Number of images classified per batch. Must be >= 1.
            **kwargs: Forwarded to the extractor.

        Yields:
            One dict per extracted image: ``{"classification": ..., **metadata}``.

        Raises:
            ValueError: If ``batch_size`` is smaller than 1.
        """
        # Without this check, funcy.chunks with a non-positive size either fails
        # with an unrelated error or never terminates.
        if batch_size < 1:
            raise ValueError(f"batch_size must be a positive integer, got {batch_size!r}")

        image_metadata_pairs = self.extractor(obj, **kwargs)
        batches = chunks(batch_size, image_metadata_pairs)
        responses = chain.from_iterable(map(rpartial(self.__process_batch, batch_size), batches))
        yield from responses