refactoring

This commit is contained in:
Matthias Bisping 2022-03-27 18:13:58 +02:00
parent 9d58ae714f
commit 334dc79f7e
4 changed files with 23 additions and 7 deletions

View File

@ -1,7 +1,8 @@
from itertools import chain
from typing import List
from typing import Iterable
from PIL.Image import Image
from funcy import rcompose
from image_prediction.classifier.classifier import Classifier
from image_prediction.estimator.preprocessor.preprocessor import Preprocessor
@ -10,11 +11,14 @@ from image_prediction.utils import chunk_iterable
class ImageClassifier:
def __init__(self, estimator: Classifier, preprocessor: Preprocessor = None):
self.estimator = estimator
"""Combines a classifier with a preprocessing pipeline: Receives images, chunks into batches, converts to tensors,
applies transformations and finally sends to internal classifier.
"""
def __init__(self, classifier: Classifier, preprocessor: Preprocessor = None):
self.estimator = classifier
self.preprocessor = preprocessor if preprocessor else BasicPreprocessor()
self.pipe = lambda batch: self.estimator(self.preprocessor(batch))
self.pipe = rcompose(self.preprocessor, self.estimator)
def predict(self, images: List[Image], batch_size=16):
def predict(self, images: Iterable[Image], batch_size=16):
batches = chunk_iterable(images, chunk_size=batch_size)
return chain(*map(self.pipe, batches))

View File

@ -1,6 +1,7 @@
import logging
import tempfile
from contextlib import contextmanager
from functools import reduce
from itertools import takewhile, starmap, islice, repeat
from operator import truth
@ -74,3 +75,8 @@ def show_banner():
@export
def chunk_iterable(iterable, chunk_size):
return takewhile(truth, map(tuple, starmap(islice, repeat((iter(iterable), chunk_size)))))
def compose(func, *funcs):
funcs = [func, *funcs]
return lambda x: reduce(lambda acc, f: f(acc), funcs, x)

View File

@ -19,5 +19,5 @@ PDFNetPython3~=9.1.0
Pillow~=8.3.2
PyYAML~=5.4.1
scikit_learn~=0.24.2
pytest~=7.1.0
pytest~=7.1.0
funcy==1.17

View File

@ -0,0 +1,6 @@
from funcy import rcompose
def test_rcompose():
f = rcompose(lambda x: x ** 2, str, lambda x: x * 2)
assert f(3) == "99"