27 lines
996 B
Python
27 lines
996 B
Python
from itertools import chain
|
|
from typing import List
|
|
|
|
from PIL.Image import Image
|
|
|
|
from image_prediction.estimator.estimator import Estimator
|
|
from image_prediction.estimator.preprocessor.preprocessor import Preprocessor
|
|
from image_prediction.estimator.preprocessor.preprocessors.basic import BasicPreprocessor
|
|
from image_prediction.utils import chunk_iterable
|
|
|
|
|
|
class Predictor:
    """Run images through a preprocessor and an estimator in fixed-size batches."""

    def __init__(self, estimator: Estimator, preprocessor: Preprocessor = None):
        """Bind the estimator and preprocessor used by :meth:`pipe`.

        Args:
            estimator: Callable applied to each preprocessed batch.
            preprocessor: Callable applied to each raw batch before the
                estimator. When omitted (``None``), a ``BasicPreprocessor``
                instance is created and used instead.
        """
        self.estimator = estimator
        # Compare against None explicitly: a caller-supplied preprocessor that
        # happens to be falsy (e.g. defines __len__/__bool__) must still be used.
        self.preprocessor = (
            preprocessor if preprocessor is not None else BasicPreprocessor()
        )

    def pipe(self, batch):
        """Preprocess *batch* and return the estimator's output for it.

        Defined as a method instead of a lambda stored on the instance so the
        predictor remains picklable and shows a readable name in tracebacks.
        ``self.estimator`` and ``self.preprocessor`` are looked up at call
        time, so reassigning either attribute takes effect immediately.
        """
        return self.estimator(self.preprocessor(batch))

    def predict_images(self, images: List[Image], batch_size=4):
        """Predict on *images*, processing them ``batch_size`` at a time.

        Args:
            images: Images to predict on.
            batch_size: Number of images handed to :meth:`pipe` per call;
                splitting is delegated to ``chunk_iterable``.

        Returns:
            An iterator over the concatenated per-batch results. Every batch
            is run through the pipeline before this method returns; only the
            flattening of the results is lazy. Each per-batch result must be
            iterable (presumably one prediction per image -- confirm against
            the Estimator implementation).
        """
        # Run each batch through the estimator exactly once. (A previous
        # version re-ran every batch just to print it, doubling inference.)
        batch_predictions = [
            self.pipe(batch)
            for batch in chunk_iterable(images, chunk_size=batch_size)
        ]
        return chain.from_iterable(batch_predictions)
|