from itertools import chain
from typing import Iterable, Optional

from PIL.Image import Image
from funcy import rcompose

from image_prediction.classifier.classifier import Classifier
from image_prediction.estimator.preprocessor.preprocessor import Preprocessor
from image_prediction.estimator.preprocessor.preprocessors.basic import BasicPreprocessor
from image_prediction.utils import chunk_iterable


class ImageClassifier:
    """Combine a classifier with a preprocessing pipeline.

    Receives images, chunks them into batches, converts them to tensors,
    applies transformations and finally sends them to the internal classifier.
    """

    def __init__(self, classifier: Classifier, preprocessor: Optional[Preprocessor] = None):
        """Build the preprocess-then-classify pipeline.

        Args:
            classifier: The classifier applied to each preprocessed batch.
            preprocessor: Callable applied to each batch before the classifier.
                Defaults to ``BasicPreprocessor()`` when omitted or ``None``.
        """
        self.estimator = classifier
        # Compare against None explicitly: a truthiness test would silently
        # replace a caller-supplied preprocessor that happens to be falsy
        # (e.g. one that defines __len__ or __bool__).
        self.preprocessor = BasicPreprocessor() if preprocessor is None else preprocessor
        # funcy.rcompose applies its functions left to right:
        # pipe(batch) == self.estimator(self.preprocessor(batch)).
        self.pipe = rcompose(self.preprocessor, self.estimator)

    def predict(self, images: Iterable[Image], batch_size: int = 16):
        """Classify ``images`` in batches of ``batch_size``.

        Args:
            images: The images to classify.
            batch_size: Passed to ``chunk_iterable`` as ``chunk_size``;
                presumably the maximum number of images per batch — confirm
                against ``image_prediction.utils.chunk_iterable``.

        Returns:
            An ``itertools.chain`` iterator over the concatenated outputs of
            the pipeline for each batch, in batch order.
        """
        batches = chunk_iterable(images, chunk_size=batch_size)
        # Every batch is run through the pipeline before this method returns
        # (the original star-unpacking into chain() was eager too); only the
        # flattening of the per-batch outputs is deferred to iteration.
        batch_outputs = [self.pipe(batch) for batch in batches]
        return chain.from_iterable(batch_outputs)