33 lines
998 B
Python
33 lines
998 B
Python
import random
|
|
from itertools import starmap
|
|
from operator import __eq__
|
|
|
|
import pytest
|
|
from PIL.Image import Image
|
|
from funcy import compose, first
|
|
|
|
from image_prediction.encoder.encoders.hash_encoder import HashEncoder
|
|
from image_prediction.encoder.encoders.hash_encoder import hash_image
|
|
from image_prediction.utils.generic import lift
|
|
|
|
|
|
def resize(image: Image):
    """Return a copy of ``image`` rescaled by one random factor between 0.3 and 2.

    Both dimensions are multiplied by the same factor and truncated to ``int``.
    Each dimension is clamped to at least 1 pixel, because Pillow rejects
    zero-sized resize targets (small images combined with a factor near 0.3
    would otherwise truncate to 0).

    :param image: The PIL image to rescale; it is not modified.
    :return: A new, resized PIL image.
    """
    factor = random.uniform(0.3, 2)
    # Pillow documents ``size`` as a (width, height) 2-tuple; pass a real tuple
    # instead of a lazy ``map`` object that only works if Pillow converts it.
    new_size = tuple(max(1, int(dim * factor)) for dim in image.size)
    return image.resize(new_size)
|
|
|
|
|
|
def close(a: str, b: str, threshold: float = 0.75) -> bool:
    """Return whether two equal-length strings match in at least ``threshold`` of positions.

    The strings are compared position by position (a normalized Hamming
    similarity). Two empty strings have no differing positions and are
    considered close.

    :param a: First string, e.g. a hash digest.
    :param b: Second string; must have the same length as ``a``.
    :param threshold: Minimum fraction of matching positions required
        (default 0.75).
    :return: ``True`` if the fraction of equal characters is >= ``threshold``.
    :raises AssertionError: If ``a`` and ``b`` differ in length.
    """
    assert len(a) == len(b), f"length mismatch: {len(a)} != {len(b)}"
    if not a:
        # Guard against ZeroDivisionError; identical empty inputs are close.
        return True
    matches = sum(starmap(__eq__, zip(a, b)))
    return matches / len(a) >= threshold
|
|
|
|
|
|
def test_hash_encoder(images, hashed_images, base_patch_image):
    """Check HashEncoder against known hashes and for rough stability under rescaling.

    The first check is deterministic and must always pass. The second check
    rescales ``base_patch_image`` by a random factor and only requires the
    resulting hash to be ``close`` to the hash of the original image. Because
    it is random, it can legitimately fail some of the time; only that check
    is reported as an expected failure.
    """
    encoder = HashEncoder()

    # Deterministic part: a regression here must fail the test outright
    # instead of being masked by a test-wide expected-failure marker.
    assert list(encoder(images)) == hashed_images

    # Stochastic part: compose applies right-to-left. lift(resize) is assumed
    # to map resize over the one-element list, and first picks the only hash.
    hashed_resized = compose(first, encoder, lift(resize))([base_patch_image])
    hashed = hash_image(base_patch_image)

    if not close(hashed_resized, hashed):
        pytest.xfail("Stochastic test, may fail some amount of the time.")
|