refactoring
This commit is contained in:
parent
364111db89
commit
7d21b0a585
@ -5,22 +5,25 @@ from image_prediction.estimator.preprocessor.preprocessors.tensor_conversion imp
|
||||
from image_prediction.estimator.preprocessor.utils import image_to_normalized_tensor, images_to_batch_tensor
|
||||
|
||||
|
||||
def image_conversion_is_correct(image):
|
||||
tensor = image_to_normalized_tensor(image)
|
||||
image_re = Image.fromarray(np.uint8(tensor * 255), mode="RGB")
|
||||
return image == image_re and tensor.ndim == 3
|
||||
|
||||
|
||||
def images_conversion_is_correct(images, tensor):
|
||||
return all([isinstance(tensor, np.ndarray), tensor.ndim == 4, tensor.shape[0] == len(images)])
|
||||
|
||||
|
||||
def test_image_to_tensor(images):
|
||||
|
||||
def inner(image):
|
||||
tensor = image_to_normalized_tensor(image)
|
||||
image_re = Image.fromarray(np.uint8(tensor * 255), mode="RGB")
|
||||
return image == image_re and tensor.ndim == 3
|
||||
|
||||
assert all(map(inner, images))
|
||||
assert all(map(image_conversion_is_correct, images))
|
||||
|
||||
|
||||
def test_images_to_batch_tensor(images):
|
||||
tensor = images_to_batch_tensor(images)
|
||||
assert isinstance(tensor, np.ndarray)
|
||||
assert tensor.ndim == 4
|
||||
assert images_conversion_is_correct(images, tensor)
|
||||
|
||||
|
||||
def test_basic_preprocessor(images):
|
||||
tensor = BasicPreprocessor().preprocess(images)
|
||||
assert tensor.ndim == 4
|
||||
tensor = BasicPreprocessor()(images)
|
||||
assert images_conversion_is_correct(images, tensor)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user