from itertools import chain
from typing import List, Optional

from PIL.Image import Image

from image_prediction.classifier.classifier import Classifier
from image_prediction.estimator.preprocessor.preprocessor import Preprocessor
from image_prediction.estimator.preprocessor.preprocessors.basic import BasicPreprocessor
from image_prediction.utils import chunk_iterable


class ImageClassifier:
    """Classify images in batches by chaining a preprocessor and an estimator.

    Each batch of images is passed through ``self.preprocessor`` and the result
    is handed to ``self.estimator``; the per-batch outputs are then flattened
    into a single iterator.
    """

    def __init__(self, estimator: Classifier, preprocessor: Optional[Preprocessor] = None):
        """Create a classifier pipeline.

        Args:
            estimator: Callable invoked on each preprocessed batch.
            preprocessor: Callable applied to each raw batch before the
                estimator. When ``None``, a ``BasicPreprocessor()`` is used.
        """
        self.estimator = estimator
        # Explicit None check: a caller-supplied preprocessor that happens to be
        # falsy (e.g. defines __bool__/__len__) must not be silently replaced.
        self.preprocessor = preprocessor if preprocessor is not None else BasicPreprocessor()

    def pipe(self, batch):
        """Preprocess one batch and return the estimator's output for it.

        Defined as a method rather than a lambda attribute so instances remain
        picklable; ``self.pipe(batch)`` works exactly as before.
        """
        return self.estimator(self.preprocessor(batch))

    def predict(self, images: List[Image], batch_size=16):
        """Predict on ``images``, processing ``batch_size`` images at a time.

        Args:
            images: Images to classify.
            batch_size: Number of images per batch handed to ``chunk_iterable``.

        Returns:
            A single-pass iterator over the items of each batch's estimator
            output, in batch order. Batches are processed lazily as the iterator
            is consumed, so only one batch's output is held at a time.
            NOTE(review): assumes the estimator returns an iterable of
            per-image predictions for each batch — confirm against Classifier.
        """
        batches = chunk_iterable(images, chunk_size=batch_size)
        # chain.from_iterable streams batch by batch; chain(*...) would force
        # every batch through the model up front and keep all outputs in memory.
        return chain.from_iterable(map(self.pipe, batches))