2022-03-28 00:01:19 +02:00

30 lines
1.2 KiB
Python

from itertools import chain
from typing import Iterable
from image_prediction.classifier.image_classifier import ImageClassifier
from image_prediction.image_extractor.extractor import ImageExtractor, ImageMetadataPair
from image_prediction.utils import chunk_iterable
class ExtractorClassifier:
    """Extracts images from an object and classifies them.

    When called, returns an iterable of dictionaries, where each dictionary has a field 'label' for the
    classification and possibly additional fields for metadata.
    """

    def __init__(self, image_extractor: ImageExtractor, image_classifier: ImageClassifier, *, batch_size: int = 16):
        """
        Args:
            image_extractor: Callable applied to the input object; its result is iterated as
                (image, metadata) pairs, where metadata is a mapping merged into each output dict.
            image_classifier: Callable applied to a tuple of images; its result is paired positionally
                with the batch's metadata (presumably one prediction per image — confirm against the
                classifier implementation).
            batch_size: Number of images passed to the classifier per call. Defaults to 16, the value
                previously hard-coded.

        Raises:
            ValueError: If batch_size is smaller than 1.
        """
        if batch_size < 1:
            raise ValueError(f"batch_size must be a positive integer, got {batch_size!r}")
        self.classifier = image_classifier
        self.extractor = image_extractor
        self.batch_size = batch_size

    def __process_batch(self, batch):
        """Classify one batch of (image, metadata) pairs.

        Returns:
            An iterable of dicts of the form {"label": prediction, **metadata}, one per zipped
            (prediction, metadata) pair. An empty batch returns an empty list without calling the
            classifier.

        Raises:
            ValueError: If the pairs cannot be unpacked into images and metadata. Previously this was
                swallowed, silently discarding the whole batch.
        """
        # Materialize so the emptiness check works whatever container/iterator chunk_iterable yields.
        batch = list(batch)
        if not batch:
            return []
        images, metadata = zip(*batch)
        predictions = self.classifier(images)
        # NOTE(review): metadata is unpacked after "label", so a metadata key named "label" overrides the
        # prediction — confirm this is intended.
        return ({"label": lbl, **mdt} for lbl, mdt in zip(predictions, metadata))

    def __call__(self, obj) -> Iterable[dict]:
        """Extract images from `obj` and classify them in batches of `self.batch_size`.

        Args:
            obj: The object to extract images from; passed unchanged to the image extractor.

        Returns:
            A lazy iterator of dicts, each with a 'label' key plus that image's metadata fields. Each batch
            is classified only when the iterator reaches it (extraction is lazy too, provided the extractor
            and chunk_iterable are).
        """
        image_metadata_pairs = self.extractor(obj)
        batches = chunk_iterable(image_metadata_pairs, chunk_size=self.batch_size)
        # chain.from_iterable stays lazy; chain(*map(...)) would push every batch through the classifier
        # up front just to build the argument tuple.
        return chain.from_iterable(map(self.__process_batch, batches))