29 lines
1.2 KiB
Python

from itertools import chain
from typing import Iterable
from image_prediction.classifier.image_classifier import ImageClassifier
from image_prediction.image_extractor.extractor import ImageExtractor, ImageMetadataPair
from image_prediction.utils import chunk_iterable
class ExtractorClassifier:
    """Extracts images from an object and classifies them.

    When called, returns an iterable of dictionaries, one per extracted image. Each dictionary has a field
    'prediction' holding the classifier's output for that image, plus the fields of that image's metadata.
    """

    def __init__(self, image_extractor: ImageExtractor, image_classifier: ImageClassifier, batch_size: int = 16):
        """
        Args:
            image_extractor: Callable applied to the object passed to ``__call__``; its result is chunked into
                batches of (image, metadata) pairs, where metadata is a mapping.
            image_classifier: Callable applied to a tuple of images. It is expected to return one prediction per
                image, in the same order (predictions are paired with metadata positionally).
            batch_size: Maximum number of images passed to the classifier per call. Defaults to 16, the value
                that was previously hard-coded.

        Raises:
            ValueError: If ``batch_size`` is less than 1.
        """
        if batch_size < 1:
            raise ValueError(f"batch_size must be a positive integer, got {batch_size!r}")
        self.classifier = image_classifier
        self.extractor = image_extractor
        self.batch_size = batch_size

    def __process_batch(self, batch):
        """Classify one batch of (image, metadata) pairs.

        Returns an iterator of dicts of the form ``{"prediction": <prediction>, **metadata}``. The classifier is
        called eagerly, once per batch; the dicts are built lazily.
        """
        pairs = list(batch)
        if not pairs:
            # zip(*[]) produces nothing, so unpacking it into two names would raise ValueError.
            return iter(())
        images, metadata = zip(*pairs)
        predictions = self.classifier(images)
        # Metadata is merged after "prediction", so a metadata key named "prediction" overrides the
        # classifier output. zip stops at the shorter input: if the classifier returns fewer predictions
        # than images, the surplus metadata entries are silently dropped.
        return ({"prediction": prd, **mdt} for prd, mdt in zip(predictions, metadata))

    def __call__(self, obj) -> Iterable[dict]:
        """Extract images from ``obj``, classify them in batches, and return the per-image result dicts.

        The returned iterator is lazy: each batch is passed to the classifier only when iteration reaches it.
        """
        image_metadata_pairs = self.extractor(obj)
        batches = chunk_iterable(image_metadata_pairs, chunk_size=self.batch_size)
        return chain.from_iterable(map(self.__process_batch, batches))