preprocessor refactoring

This commit is contained in:
Matthias Bisping 2022-03-26 19:38:34 +01:00
parent ea298dacfa
commit 364111db89
4 changed files with 25 additions and 6 deletions

View File

@ -1,12 +1,12 @@
from image_prediction.estimator.adapter.adapter import EstimatorAdapter
import abc
class Preprocessor:
def __init__(self):
class Preprocessor(abc.ABC):
@staticmethod
@abc.abstractmethod
def preprocess(batch):
pass
# def preprocess(self, batch):
# return (map(image_to_normalized_tensor, batch))
def __call__(self, batch):
return self.preprocess(batch)

View File

@ -0,0 +1,13 @@
from image_prediction.estimator.preprocessor.preprocessor import Preprocessor
from image_prediction.estimator.preprocessor.utils import images_to_batch_tensor
class BasicPreprocessor(Preprocessor):
"""Converts images to tensors"""
@staticmethod
def preprocess(images):
return images_to_batch_tensor(images)
def __call__(self, images):
return self.preprocess(images)

View File

@ -1,6 +1,7 @@
import numpy as np
from PIL import Image
from image_prediction.estimator.preprocessor.preprocessors.tensor_conversion import BasicPreprocessor
from image_prediction.estimator.preprocessor.utils import image_to_normalized_tensor, images_to_batch_tensor
@ -18,3 +19,8 @@ def test_images_to_batch_tensor(images):
tensor = images_to_batch_tensor(images)
assert isinstance(tensor, np.ndarray)
assert tensor.ndim == 4
def test_basic_preprocessor(images):
tensor = BasicPreprocessor().preprocess(images)
assert tensor.ndim == 4