refactoring

This commit is contained in:
Matthias Bisping 2022-03-27 18:13:58 +02:00
parent 9d58ae714f
commit 334dc79f7e
4 changed files with 23 additions and 7 deletions

View File

@ -1,7 +1,8 @@
from itertools import chain from itertools import chain
from typing import List from typing import Iterable
from PIL.Image import Image from PIL.Image import Image
from funcy import rcompose
from image_prediction.classifier.classifier import Classifier from image_prediction.classifier.classifier import Classifier
from image_prediction.estimator.preprocessor.preprocessor import Preprocessor from image_prediction.estimator.preprocessor.preprocessor import Preprocessor
@ -10,11 +11,14 @@ from image_prediction.utils import chunk_iterable
class ImageClassifier: class ImageClassifier:
def __init__(self, estimator: Classifier, preprocessor: Preprocessor = None): """Combines a classifier with a preprocessing pipeline: Receives images, chunks into batches, converts to tensors,
self.estimator = estimator applies transformations and finally sends to internal classifier.
"""
def __init__(self, classifier: Classifier, preprocessor: Preprocessor = None):
self.estimator = classifier
self.preprocessor = preprocessor if preprocessor else BasicPreprocessor() self.preprocessor = preprocessor if preprocessor else BasicPreprocessor()
self.pipe = lambda batch: self.estimator(self.preprocessor(batch)) self.pipe = rcompose(self.preprocessor, self.estimator)
def predict(self, images: List[Image], batch_size=16): def predict(self, images: Iterable[Image], batch_size=16):
batches = chunk_iterable(images, chunk_size=batch_size) batches = chunk_iterable(images, chunk_size=batch_size)
return chain(*map(self.pipe, batches)) return chain(*map(self.pipe, batches))

View File

@ -1,6 +1,7 @@
import logging import logging
import tempfile import tempfile
from contextlib import contextmanager from contextlib import contextmanager
from functools import reduce
from itertools import takewhile, starmap, islice, repeat from itertools import takewhile, starmap, islice, repeat
from operator import truth from operator import truth
@ -74,3 +75,8 @@ def show_banner():
@export @export
def chunk_iterable(iterable, chunk_size): def chunk_iterable(iterable, chunk_size):
return takewhile(truth, map(tuple, starmap(islice, repeat((iter(iterable), chunk_size))))) return takewhile(truth, map(tuple, starmap(islice, repeat((iter(iterable), chunk_size)))))
def compose(func, *funcs):
funcs = [func, *funcs]
return lambda x: reduce(lambda acc, f: f(acc), funcs, x)

View File

@ -19,5 +19,5 @@ PDFNetPython3~=9.1.0
Pillow~=8.3.2 Pillow~=8.3.2
PyYAML~=5.4.1 PyYAML~=5.4.1
scikit_learn~=0.24.2 scikit_learn~=0.24.2
pytest~=7.1.0 pytest~=7.1.0
funcy==1.17

View File

@ -0,0 +1,6 @@
from funcy import rcompose
def test_rcompose():
f = rcompose(lambda x: x ** 2, str, lambda x: x * 2)
assert f(3) == "99"