refactoring
This commit is contained in:
parent
9d58ae714f
commit
334dc79f7e
@ -1,7 +1,8 @@
|
|||||||
from itertools import chain
|
from itertools import chain
|
||||||
from typing import List
|
from typing import Iterable
|
||||||
|
|
||||||
from PIL.Image import Image
|
from PIL.Image import Image
|
||||||
|
from funcy import rcompose
|
||||||
|
|
||||||
from image_prediction.classifier.classifier import Classifier
|
from image_prediction.classifier.classifier import Classifier
|
||||||
from image_prediction.estimator.preprocessor.preprocessor import Preprocessor
|
from image_prediction.estimator.preprocessor.preprocessor import Preprocessor
|
||||||
@ -10,11 +11,14 @@ from image_prediction.utils import chunk_iterable
|
|||||||
|
|
||||||
|
|
||||||
class ImageClassifier:
|
class ImageClassifier:
|
||||||
def __init__(self, estimator: Classifier, preprocessor: Preprocessor = None):
|
"""Combines a classifier with a preprocessing pipeline: Receives images, chunks into batches, converts to tensors,
|
||||||
self.estimator = estimator
|
applies transformations and finally sends to internal classifier.
|
||||||
|
"""
|
||||||
|
def __init__(self, classifier: Classifier, preprocessor: Preprocessor = None):
|
||||||
|
self.estimator = classifier
|
||||||
self.preprocessor = preprocessor if preprocessor else BasicPreprocessor()
|
self.preprocessor = preprocessor if preprocessor else BasicPreprocessor()
|
||||||
self.pipe = lambda batch: self.estimator(self.preprocessor(batch))
|
self.pipe = rcompose(self.preprocessor, self.estimator)
|
||||||
|
|
||||||
def predict(self, images: List[Image], batch_size=16):
|
def predict(self, images: Iterable[Image], batch_size=16):
|
||||||
batches = chunk_iterable(images, chunk_size=batch_size)
|
batches = chunk_iterable(images, chunk_size=batch_size)
|
||||||
return chain(*map(self.pipe, batches))
|
return chain(*map(self.pipe, batches))
|
||||||
|
|||||||
@ -1,6 +1,7 @@
|
|||||||
import logging
|
import logging
|
||||||
import tempfile
|
import tempfile
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
|
from functools import reduce
|
||||||
from itertools import takewhile, starmap, islice, repeat
|
from itertools import takewhile, starmap, islice, repeat
|
||||||
from operator import truth
|
from operator import truth
|
||||||
|
|
||||||
@ -74,3 +75,8 @@ def show_banner():
|
|||||||
@export
|
@export
|
||||||
def chunk_iterable(iterable, chunk_size):
|
def chunk_iterable(iterable, chunk_size):
|
||||||
return takewhile(truth, map(tuple, starmap(islice, repeat((iter(iterable), chunk_size)))))
|
return takewhile(truth, map(tuple, starmap(islice, repeat((iter(iterable), chunk_size)))))
|
||||||
|
|
||||||
|
|
||||||
|
def compose(func, *funcs):
|
||||||
|
funcs = [func, *funcs]
|
||||||
|
return lambda x: reduce(lambda acc, f: f(acc), funcs, x)
|
||||||
|
|||||||
@ -19,5 +19,5 @@ PDFNetPython3~=9.1.0
|
|||||||
Pillow~=8.3.2
|
Pillow~=8.3.2
|
||||||
PyYAML~=5.4.1
|
PyYAML~=5.4.1
|
||||||
scikit_learn~=0.24.2
|
scikit_learn~=0.24.2
|
||||||
|
|
||||||
pytest~=7.1.0
|
pytest~=7.1.0
|
||||||
|
funcy==1.17
|
||||||
|
|||||||
6
test/unit_tests/utils_test.py
Normal file
6
test/unit_tests/utils_test.py
Normal file
@ -0,0 +1,6 @@
|
|||||||
|
from funcy import rcompose
|
||||||
|
|
||||||
|
|
||||||
|
def test_rcompose():
|
||||||
|
f = rcompose(lambda x: x ** 2, str, lambda x: x * 2)
|
||||||
|
assert f(3) == "99"
|
||||||
Loading…
x
Reference in New Issue
Block a user