from itertools import chain
from typing import Iterator, List, Optional

from PIL.Image import Image

from image_prediction.estimator.estimator import Estimator
from image_prediction.estimator.preprocessor.preprocessor import Preprocessor
from image_prediction.estimator.preprocessor.preprocessors.basic import BasicPreprocessor
from image_prediction.utils import chunk_iterable


class Predictor:
    """Run an estimator over images in batches, preprocessing each batch first."""

    def __init__(self, estimator: Estimator, preprocessor: Optional[Preprocessor] = None):
        """Store the estimator and preprocessor and build the per-batch pipeline.

        Args:
            estimator: Callable applied to the preprocessor's output for a batch.
            preprocessor: Callable applied to each raw batch. When ``None``, a
                ``BasicPreprocessor`` instance is created and used instead.
        """
        self.estimator = estimator
        # Compare against None explicitly so that a valid preprocessor object
        # which happens to be falsy (e.g. defines __len__ or __bool__) is not
        # silently replaced by the default.
        self.preprocessor = preprocessor if preprocessor is not None else BasicPreprocessor()
        # The attribute lookups happen at call time, so reassigning
        # ``self.estimator`` or ``self.preprocessor`` later is picked up.
        self.pipe = lambda batch: self.estimator(self.preprocessor(batch))

    def predict_images(self, images: List[Image], batch_size: int = 4) -> Iterator:
        """Predict on ``images`` batch by batch and chain the per-batch results.

        Args:
            images: The images to predict on.
            batch_size: Passed to ``chunk_iterable`` as ``chunk_size``
                (presumably the maximum number of images per batch; confirm
                against ``chunk_iterable``).

        Returns:
            An ``itertools.chain`` iterator over the elements of every
            per-batch result, in batch order. Each result returned by the
            pipeline must therefore be iterable.
        """
        batches = chunk_iterable(images, chunk_size=batch_size)
        # Every batch goes through the pipeline exactly once. The results are
        # collected eagerly so that inference (and any error it raises) happens
        # during this call rather than when the caller iterates the result.
        results = [self.pipe(batch) for batch in batches]
        return chain.from_iterable(results)