2022-04-19 14:20:54 +02:00

41 lines
1.2 KiB
Python

from itertools import starmap, product, repeat
from typing import Iterable
import numpy as np
from PIL.Image import Image
from frozendict import frozendict
from funcy import ilen
from image_prediction.estimator.preprocessor.utils import image_to_normalized_tensor
def transform_equal(a, b):
    """Return ``a == b``, first expanding *a* into a list when it is a ``map``.

    Only ``map`` objects are materialised; any other value (including other
    iterators such as generators) is compared as-is.
    """
    if isinstance(a, map):
        a = list(a)
    return a == b
def images_equal(im1: Image, im2: Image, **kwargs) -> bool:
    """Return True if two images are numerically equal after normalisation.

    Both images are converted with ``image_to_normalized_tensor`` and compared
    element-wise with ``np.allclose``; ``kwargs`` (e.g. ``rtol``, ``atol``,
    ``equal_nan``) are forwarded to ``np.allclose``.

    Tensors of different shapes are reported as unequal. Without this check
    ``np.allclose`` either raises ``ValueError`` (non-broadcastable shapes) or
    broadcasts the smaller operand and may report a false match.
    """
    tensor1 = image_to_normalized_tensor(im1)
    tensor2 = image_to_normalized_tensor(im2)
    # np.shape works for ndarrays and any array-like exposing ``.shape``.
    if np.shape(tensor1) != np.shape(tensor2):
        return False
    return np.allclose(tensor1, tensor2, **kwargs)
def metadata_equal(mdat1: Iterable[dict], mdat2: Iterable[dict]):
    """Return True if both iterables contain the same distinct metadata dicts.

    Each dict is wrapped in a ``frozendict`` so it can be placed in a set, and
    the two resulting sets are compared. Order and multiplicity are therefore
    ignored: ``[a, a]`` compares equal to ``[a]``.
    """
    frozen1 = {frozendict(record) for record in mdat1}
    frozen2 = {frozendict(record) for record in mdat2}
    return frozen1 == frozen2
def _has_perfect_matching(close) -> bool:
    """Return True if square boolean matrix *close* admits a perfect matching.

    ``close[i][j]`` states that row item ``i`` may be paired with column item
    ``j``. Uses Kuhn's augmenting-path algorithm: rows are matched one at a
    time, re-routing earlier pairings when that frees a compatible column.
    """
    n = len(close)
    owner = [None] * n  # owner[col]: row currently paired with column col

    def assign(row, visited):
        for col in range(n):
            if close[row][col] and col not in visited:
                visited.add(col)
                # Take a free column, or evict its owner if the owner can move.
                if owner[col] is None or assign(owner[col], visited):
                    owner[col] = row
                    return True
        return False

    # If any row cannot be augmented, the maximum matching is smaller than n.
    return all(assign(row, set()) for row in range(n))


def image_sets_equal(ims1: Iterable[Image], ims2: Iterable[Image], **kwargs) -> bool:
    """Return True if the two image collections are equal up to ordering.

    Each image is converted with ``image_to_normalized_tensor`` exactly once,
    and tensors are compared with ``np.allclose`` (``kwargs`` such as ``rtol``
    and ``atol`` are forwarded). The collections are equal when they have the
    same size and a one-to-one pairing exists in which every pair has the same
    shape and is close.

    ``np.allclose`` closeness is not transitive, so a greedy first-fit pairing
    can miss a valid pairing; a maximum bipartite matching is used instead.
    Collections of different sizes yield False rather than raising.
    """
    tensors1 = [image_to_normalized_tensor(im) for im in ims1]
    tensors2 = [image_to_normalized_tensor(im) for im in ims2]
    if len(tensors1) != len(tensors2):
        return False

    def close(t1, t2):
        # Mismatched shapes would make np.allclose raise or broadcast.
        return np.shape(t1) == np.shape(t2) and np.allclose(t1, t2, **kwargs)

    adjacency = [[close(t1, t2) for t2 in tensors2] for t1 in tensors1]
    return _has_perfect_matching(adjacency)