added identity preprocessor; changed default preprocessor to idenitity
This commit is contained in:
parent
6b58756103
commit
a1c7dd4a8d
@ -6,7 +6,7 @@ from funcy import rcompose
|
||||
|
||||
from image_prediction.classifier.classifier import Classifier
|
||||
from image_prediction.estimator.preprocessor.preprocessor import Preprocessor
|
||||
from image_prediction.estimator.preprocessor.preprocessors.basic import BasicPreprocessor
|
||||
from image_prediction.estimator.preprocessor.preprocessors.identity import IdentityPreprocessor
|
||||
from image_prediction.utils import chunk_iterable
|
||||
|
||||
|
||||
@ -17,7 +17,7 @@ class ImageClassifier:
|
||||
|
||||
def __init__(self, classifier: Classifier, preprocessor: Preprocessor = None):
|
||||
self.estimator = classifier
|
||||
self.preprocessor = preprocessor if preprocessor else BasicPreprocessor()
|
||||
self.preprocessor = preprocessor if preprocessor else IdentityPreprocessor()
|
||||
self.pipe = rcompose(self.preprocessor, self.estimator)
|
||||
|
||||
def predict(self, images: Iterable[Image], batch_size=16):
|
||||
|
||||
@ -0,0 +1,11 @@
|
||||
from image_prediction.estimator.preprocessor.preprocessor import Preprocessor
|
||||
|
||||
|
||||
class IdentityPreprocessor(Preprocessor):
|
||||
|
||||
@staticmethod
|
||||
def preprocess(images):
|
||||
return images
|
||||
|
||||
def __call__(self, images):
|
||||
return self.preprocess(images)
|
||||
@ -12,6 +12,7 @@ from image_prediction.classifier.classifier import Classifier
|
||||
from image_prediction.classifier.image_classifier import ImageClassifier
|
||||
from image_prediction.estimator.adapter.adapters.keras import KerasEstimatorAdapter
|
||||
from image_prediction.estimator.adapter.adapters.mock import EstimatorMock, EstimatorAdapterMock
|
||||
from image_prediction.estimator.preprocessor.preprocessors.basic import BasicPreprocessor
|
||||
from image_prediction.exceptions import UnknownEstimatorAdapter, UnknownImageExtractor, UnknownModelLoader
|
||||
from image_prediction.image_extractor.extractor import ImageMetadataPair
|
||||
from image_prediction.image_extractor.extractors.mock import ImageExtractorMock
|
||||
@ -33,7 +34,7 @@ def image_extractor(extractor_type):
|
||||
|
||||
@pytest.fixture
|
||||
def image_classifier(classifier, monkeypatch, expected_predictions):
|
||||
return ImageClassifier(classifier)
|
||||
return ImageClassifier(classifier, preprocessor=BasicPreprocessor())
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
||||
@ -3,6 +3,7 @@ import pytest
|
||||
from PIL import Image
|
||||
|
||||
from image_prediction.estimator.preprocessor.preprocessors.basic import BasicPreprocessor
|
||||
from image_prediction.estimator.preprocessor.preprocessors.identity import IdentityPreprocessor
|
||||
from image_prediction.estimator.preprocessor.utils import image_to_normalized_tensor, images_to_batch_tensor
|
||||
|
||||
|
||||
@ -35,3 +36,9 @@ def test_images_to_batch_tensor(images):
|
||||
def test_basic_preprocessor(images):
|
||||
tensor = BasicPreprocessor()(images)
|
||||
assert images_conversion_is_correct(images, tensor)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("batch_size", [0, 1, 2, 4, 6], scope="session")
|
||||
def test_identity_preprocessor(images):
|
||||
images_preprocessed = IdentityPreprocessor()(images)
|
||||
assert images_preprocessed == images
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user