Matthias Bisping 9d58ae714f renaming
2022-03-27 17:55:01 +02:00

21 lines
832 B
Python

from itertools import chain
from typing import List, Optional

from PIL.Image import Image

from image_prediction.classifier.classifier import Classifier
from image_prediction.estimator.preprocessor.preprocessor import Preprocessor
from image_prediction.estimator.preprocessor.preprocessors.basic import BasicPreprocessor
from image_prediction.utils import chunk_iterable
class ImageClassifier:
    """Run an estimator over a list of images in fixed-size batches.

    Each batch is passed through the preprocessor and then the estimator; the
    per-batch outputs are flattened into a single iterator by ``predict``.
    """

    def __init__(self, estimator: Classifier, preprocessor: Optional[Preprocessor] = None):
        """
        Args:
            estimator: Called as ``estimator(preprocessed_batch)``; its return
                value must be iterable because ``predict`` flattens it.
            preprocessor: Called as ``preprocessor(batch)`` on each raw batch of
                images before the estimator. When ``None``, a new
                ``BasicPreprocessor`` is used.
        """
        self.estimator = estimator
        # Compare against None explicitly: a caller-supplied preprocessor must
        # never be swapped for the default just because it happens to be falsy.
        self.preprocessor = preprocessor if preprocessor is not None else BasicPreprocessor()

    def pipe(self, batch):
        """Preprocess one batch of images and return the estimator's output for it.

        Defined as a method (rather than a lambda stored on the instance) so that
        instances stay picklable and subclasses can override it; it reads
        ``self.estimator`` / ``self.preprocessor`` at call time.
        """
        return self.estimator(self.preprocessor(batch))

    def predict(self, images: List[Image], batch_size=16):
        """Predict on ``images`` in batches of ``batch_size``.

        Args:
            images: Images to classify.
            batch_size: Number of images handed to ``pipe`` per call; must be >= 1.

        Returns:
            An iterator over the concatenated per-batch estimator outputs, in
            batch order.

        Raises:
            ValueError: If ``batch_size`` is less than 1.
        """
        if batch_size < 1:
            raise ValueError(f"batch_size must be a positive integer, got {batch_size!r}")
        batches = chunk_iterable(images, chunk_size=batch_size)
        # All batches are run here, at call time (same as before); only the
        # flattening of the per-batch results is deferred to iteration.
        per_batch_outputs = [self.pipe(batch) for batch in batches]
        return chain.from_iterable(per_batch_outputs)