preprocessor refactoring
This commit is contained in:
parent
ea298dacfa
commit
364111db89
@ -1,12 +1,12 @@
|
||||
from image_prediction.estimator.adapter.adapter import EstimatorAdapter
|
||||
import abc
|
||||
|
||||
|
||||
class Preprocessor:
|
||||
def __init__(self):
|
||||
class Preprocessor(abc.ABC):
|
||||
|
||||
@staticmethod
|
||||
@abc.abstractmethod
|
||||
def preprocess(batch):
|
||||
pass
|
||||
|
||||
# def preprocess(self, batch):
|
||||
# return (map(image_to_normalized_tensor, batch))
|
||||
|
||||
def __call__(self, batch):
|
||||
return self.preprocess(batch)
|
||||
|
||||
@ -0,0 +1,13 @@
|
||||
from image_prediction.estimator.preprocessor.preprocessor import Preprocessor
|
||||
from image_prediction.estimator.preprocessor.utils import images_to_batch_tensor
|
||||
|
||||
|
||||
class BasicPreprocessor(Preprocessor):
|
||||
"""Converts images to tensors"""
|
||||
|
||||
@staticmethod
|
||||
def preprocess(images):
|
||||
return images_to_batch_tensor(images)
|
||||
|
||||
def __call__(self, images):
|
||||
return self.preprocess(images)
|
||||
@ -1,6 +1,7 @@
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
|
||||
from image_prediction.estimator.preprocessor.preprocessors.tensor_conversion import BasicPreprocessor
|
||||
from image_prediction.estimator.preprocessor.utils import image_to_normalized_tensor, images_to_batch_tensor
|
||||
|
||||
|
||||
@ -18,3 +19,8 @@ def test_images_to_batch_tensor(images):
|
||||
tensor = images_to_batch_tensor(images)
|
||||
assert isinstance(tensor, np.ndarray)
|
||||
assert tensor.ndim == 4
|
||||
|
||||
|
||||
def test_basic_preprocessor(images):
|
||||
tensor = BasicPreprocessor().preprocess(images)
|
||||
assert tensor.ndim == 4
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user