added identity preprocessor; changed default preprocessor to idenitity

This commit is contained in:
Matthias Bisping 2022-03-29 11:40:58 +02:00
parent 6b58756103
commit a1c7dd4a8d
4 changed files with 22 additions and 3 deletions

View File

@ -6,7 +6,7 @@ from funcy import rcompose
from image_prediction.classifier.classifier import Classifier
from image_prediction.estimator.preprocessor.preprocessor import Preprocessor
from image_prediction.estimator.preprocessor.preprocessors.basic import BasicPreprocessor
from image_prediction.estimator.preprocessor.preprocessors.identity import IdentityPreprocessor
from image_prediction.utils import chunk_iterable
@ -17,7 +17,7 @@ class ImageClassifier:
def __init__(self, classifier: Classifier, preprocessor: Preprocessor = None):
self.estimator = classifier
self.preprocessor = preprocessor if preprocessor else BasicPreprocessor()
self.preprocessor = preprocessor if preprocessor else IdentityPreprocessor()
self.pipe = rcompose(self.preprocessor, self.estimator)
def predict(self, images: Iterable[Image], batch_size=16):

View File

@ -0,0 +1,11 @@
from image_prediction.estimator.preprocessor.preprocessor import Preprocessor
class IdentityPreprocessor(Preprocessor):
@staticmethod
def preprocess(images):
return images
def __call__(self, images):
return self.preprocess(images)

View File

@ -12,6 +12,7 @@ from image_prediction.classifier.classifier import Classifier
from image_prediction.classifier.image_classifier import ImageClassifier
from image_prediction.estimator.adapter.adapters.keras import KerasEstimatorAdapter
from image_prediction.estimator.adapter.adapters.mock import EstimatorMock, EstimatorAdapterMock
from image_prediction.estimator.preprocessor.preprocessors.basic import BasicPreprocessor
from image_prediction.exceptions import UnknownEstimatorAdapter, UnknownImageExtractor, UnknownModelLoader
from image_prediction.image_extractor.extractor import ImageMetadataPair
from image_prediction.image_extractor.extractors.mock import ImageExtractorMock
@ -33,7 +34,7 @@ def image_extractor(extractor_type):
@pytest.fixture
def image_classifier(classifier, monkeypatch, expected_predictions):
return ImageClassifier(classifier)
return ImageClassifier(classifier, preprocessor=BasicPreprocessor())
@pytest.fixture

View File

@ -3,6 +3,7 @@ import pytest
from PIL import Image
from image_prediction.estimator.preprocessor.preprocessors.basic import BasicPreprocessor
from image_prediction.estimator.preprocessor.preprocessors.identity import IdentityPreprocessor
from image_prediction.estimator.preprocessor.utils import image_to_normalized_tensor, images_to_batch_tensor
@ -35,3 +36,9 @@ def test_images_to_batch_tensor(images):
def test_basic_preprocessor(images):
tensor = BasicPreprocessor()(images)
assert images_conversion_is_correct(images, tensor)
@pytest.mark.parametrize("batch_size", [0, 1, 2, 4, 6], scope="session")
def test_identity_preprocessor(images):
images_preprocessed = IdentityPreprocessor()(images)
assert images_preprocessed == images