diff --git a/image_prediction/classifier/image_classifier.py b/image_prediction/classifier/image_classifier.py index 5f77cb8..4467881 100644 --- a/image_prediction/classifier/image_classifier.py +++ b/image_prediction/classifier/image_classifier.py @@ -1,7 +1,8 @@ from itertools import chain -from typing import List +from typing import Iterable from PIL.Image import Image +from funcy import rcompose from image_prediction.classifier.classifier import Classifier from image_prediction.estimator.preprocessor.preprocessor import Preprocessor @@ -10,11 +11,14 @@ from image_prediction.utils import chunk_iterable class ImageClassifier: - def __init__(self, estimator: Classifier, preprocessor: Preprocessor = None): - self.estimator = estimator + """Combines a classifier with a preprocessing pipeline: Receives images, chunks into batches, converts to tensors, + applies transformations and finally sends to internal classifier. + """ + def __init__(self, classifier: Classifier, preprocessor: Preprocessor = None): + self.estimator = classifier self.preprocessor = preprocessor if preprocessor else BasicPreprocessor() - self.pipe = lambda batch: self.estimator(self.preprocessor(batch)) + self.pipe = rcompose(self.preprocessor, self.estimator) - def predict(self, images: List[Image], batch_size=16): + def predict(self, images: Iterable[Image], batch_size=16): batches = chunk_iterable(images, chunk_size=batch_size) return chain(*map(self.pipe, batches)) diff --git a/image_prediction/utils.py b/image_prediction/utils.py index 578fda2..d138381 100644 --- a/image_prediction/utils.py +++ b/image_prediction/utils.py @@ -1,6 +1,7 @@ import logging import tempfile from contextlib import contextmanager +from functools import reduce from itertools import takewhile, starmap, islice, repeat from operator import truth @@ -74,3 +75,8 @@ def show_banner(): @export def chunk_iterable(iterable, chunk_size): return takewhile(truth, map(tuple, starmap(islice, repeat((iter(iterable), chunk_size))))) + + +def compose(func, *funcs): + funcs = [func, *funcs] + return lambda x: reduce(lambda acc, f: f(acc), funcs, x) diff --git a/requirements.txt b/requirements.txt index 217a846..a0a88e6 100644 --- a/requirements.txt +++ b/requirements.txt @@ -19,5 +19,5 @@ PDFNetPython3~=9.1.0 Pillow~=8.3.2 PyYAML~=5.4.1 scikit_learn~=0.24.2 - -pytest~=7.1.0 \ No newline at end of file +pytest~=7.1.0 +funcy==1.17 diff --git a/test/unit_tests/utils_test.py b/test/unit_tests/utils_test.py new file mode 100644 index 0000000..5032396 --- /dev/null +++ b/test/unit_tests/utils_test.py @@ -0,0 +1,6 @@ +from funcy import rcompose + + +def test_rcompose(): + f = rcompose(lambda x: x ** 2, str, lambda x: x * 2) + assert f(3) == "99"