from itertools import chain
from typing import Iterable, Optional

from PIL.Image import Image
from funcy import rcompose, chunks

from image_prediction.classifier.classifier import Classifier
from image_prediction.estimator.preprocessor.preprocessor import Preprocessor
from image_prediction.estimator.preprocessor.preprocessors.identity import IdentityPreprocessor
from image_prediction.utils import get_logger

logger = get_logger()


class ImageClassifier:
    """Combines a classifier with a preprocessing pipeline.

    Receives images, chunks them into batches, passes each batch through the
    preprocessor and then hands the preprocessed batch to the wrapped
    classifier. (Presumably the preprocessor converts images to tensors and
    applies transformations — that happens inside the preprocessor, not here;
    TODO confirm against the Preprocessor implementations.)
    """

    def __init__(self, classifier: Classifier, preprocessor: Optional[Preprocessor] = None):
        """Build the preprocess-then-classify pipeline.

        Args:
            classifier: Callable applied to each preprocessed batch.
            preprocessor: Callable applied to each raw batch before
                classification. Defaults to ``IdentityPreprocessor()`` when
                omitted or ``None``.
        """
        # Attribute name ``estimator`` is kept for backward compatibility with callers.
        self.estimator = classifier
        # Compare against None rather than relying on truthiness: a valid
        # preprocessor object that happens to be falsy (e.g. defines
        # ``__len__``/``__bool__``) must not be silently replaced.
        self.preprocessor = IdentityPreprocessor() if preprocessor is None else preprocessor
        # rcompose applies functions left to right: preprocess first, then classify.
        self.pipe = rcompose(self.preprocessor, self.estimator)

    def predict(self, images: Iterable[Image], batch_size: int = 16):
        """Run the pipeline over ``images`` in batches of ``batch_size``.

        Args:
            images: Images to classify; any iterable (consumed lazily).
            batch_size: Number of images passed to the pipeline per call.

        Returns:
            A lazy iterator that flattens the per-batch pipeline outputs into a
            single stream (assumes the pipeline returns one prediction per image
            in the batch — TODO confirm).

        Raises:
            ValueError: If ``batch_size`` is smaller than 1. Without this check a
                negative size silently dropped every image, and zero behaved
                differently for sequences (error) versus iterators (no output).
        """
        if batch_size < 1:
            raise ValueError(f"batch_size must be a positive integer, got {batch_size!r}")
        batches = chunks(batch_size, images)
        return chain.from_iterable(map(self.pipe, batches))

    def __call__(self, images: Iterable[Image], batch_size: int = 16):
        """Generator wrapper around :meth:`predict`.

        Because this is a generator, the debug log line and any ``batch_size``
        validation only run once iteration starts.
        """
        logger.debug("ImageClassifier.predict")
        yield from self.predict(images, batch_size=batch_size)