2022-03-28 00:01:19 +02:00

28 lines
1.2 KiB
Python

from itertools import chain
from typing import Iterable
from PIL.Image import Image
from funcy import rcompose
from image_prediction.classifier.classifier import Classifier
from image_prediction.estimator.preprocessor.preprocessor import Preprocessor
from image_prediction.estimator.preprocessor.preprocessors.basic import BasicPreprocessor
from image_prediction.utils import chunk_iterable
class ImageClassifier:
    """Combine a classifier with a preprocessing pipeline.

    Receives images, chunks them into batches, converts them to tensors, applies transformations and finally
    sends them to the internal classifier.
    """

    def __init__(self, classifier: Classifier, preprocessor: Preprocessor = None):
        """Build the preprocessing -> classification pipeline.

        Args:
            classifier: Callable applied to the output of the preprocessor.
            preprocessor: Callable applied to each batch of images before classification. When omitted
                (``None``), a ``BasicPreprocessor`` is created.
        """
        self.estimator = classifier
        # Compare against None explicitly: a caller-supplied preprocessor that happens to evaluate as falsy
        # (e.g. one defining ``__len__`` or ``__bool__``) must not be silently swapped for the default.
        self.preprocessor = BasicPreprocessor() if preprocessor is None else preprocessor
        # rcompose chains left to right: the preprocessor runs first, its output is fed to the estimator.
        self.pipe = rcompose(self.preprocessor, self.estimator)

    def predict(self, images: Iterable[Image], batch_size=16):
        """Run the pipeline over ``images`` batch by batch.

        Args:
            images: Images to classify.
            batch_size: Passed to ``chunk_iterable`` as ``chunk_size``.

        Returns:
            A single iterator over the per-batch pipeline outputs, flattened in batch order. Batches are
            processed on demand as the iterator is consumed, so any error raised by the pipeline surfaces
            during iteration rather than at call time.
        """
        batches = chunk_iterable(images, chunk_size=batch_size)
        # chain.from_iterable pulls one batch result at a time; star-unpacking into chain() would force every
        # batch through the pipeline and keep all results in memory before the first prediction is returned.
        return chain.from_iterable(map(self.pipe, batches))

    def __call__(self, images: Iterable[Image], batch_size=16):
        """Make the instance callable; equivalent to ``predict(images, batch_size=batch_size)``."""
        return self.predict(images, batch_size=batch_size)