45 lines
1.3 KiB
Python
45 lines
1.3 KiB
Python
from functools import partial
|
|
from itertools import starmap, product, repeat
|
|
from typing import Iterable
|
|
|
|
import numpy as np
|
|
from PIL.Image import Image
|
|
from frozendict import frozendict
|
|
from funcy import ilen, compose, omit
|
|
|
|
from image_prediction.estimator.preprocessor.utils import image_to_normalized_tensor
|
|
from image_prediction.info import Info
|
|
from image_prediction.utils.generic import lift
|
|
|
|
|
|
def transform_equal(a, b):
    """Return whether ``a`` equals ``b``, materializing ``a`` first if it is a ``map``.

    A ``map`` object never compares equal to a list, so it is drained into a list
    before the comparison. Any other value of ``a`` is compared as-is.
    """
    if isinstance(a, map):
        a = list(a)  # note: consumes the map object
    return a == b
|
|
|
|
|
|
def images_equal(im1: Image, im2: Image, **kwargs):
    """Return True if two images convert to same-shaped, element-wise close tensors.

    Both images are converted with ``image_to_normalized_tensor`` and compared with
    ``np.allclose``. Extra keyword arguments (e.g. ``rtol``, ``atol``) are forwarded
    to ``np.allclose``.

    Tensors of different shapes are reported as unequal (False). Otherwise they
    would be broadcast against each other, which can give a false positive, or
    ``np.allclose`` would raise ``ValueError``.
    """
    tensor1 = image_to_normalized_tensor(im1)
    tensor2 = image_to_normalized_tensor(im2)
    # np.allclose broadcasts its operands, so equality of shape must be checked explicitly.
    return np.shape(tensor1) == np.shape(tensor2) and np.allclose(tensor1, tensor2, **kwargs)
|
|
|
|
|
|
def metadata_equal(mdat1: Iterable[dict], mdat2: Iterable[dict]):
    """Compare two collections of metadata dicts, ignoring the ``Info.XREF`` key.

    Each side goes through the same pipeline:
    1. drop ``Info.XREF`` with ``funcy.omit``;
    2. freeze the result into a hashable ``frozendict``;
    3. apply ``lift`` and then ``set``.

    NOTE(review): assuming ``lift`` maps its function over each element, this is a
    set comparison. Order is ignored, and duplicate entries collapse — confirm
    this is intended.
    """
    drop_xref_and_freeze = compose(frozendict, partial(omit, keys=[Info.XREF]))
    normalize = compose(set, lift(drop_xref_and_freeze))
    left, right = normalize(mdat1), normalize(mdat2)
    return left == right
|
|
|
|
|
|
def image_sets_equal(ims1: Iterable[Image], ims2: Iterable[Image]):
    """Return True if two image collections match one-to-one, regardless of order.

    Each image is converted exactly once with ``image_to_normalized_tensor``. Every
    tensor from ``ims1`` is then greedily paired with the first not-yet-used tensor
    from ``ims2`` that has the same shape and is ``np.allclose`` to it.

    Returns False, rather than raising, when:
    - the collections have different sizes;
    - some image has no match (including images of differing dimensions).
    Two empty collections are equal.

    Note: greedy first-fit pairing can miss a valid assignment when closeness is not
    transitive (several near-duplicate images). It is exact for identical images.
    """

    def normalized(images):
        # Ordering by mean is only a heuristic so likely partners are tried early;
        # correctness does not depend on it, because every unused candidate is searched.
        return sorted(map(image_to_normalized_tensor, images), key=np.mean)

    tensors1, tensors2 = normalized(ims1), normalized(ims2)

    # An explicit check, not `assert`: this is a predicate, and asserts vanish under -O.
    if len(tensors1) != len(tensors2):
        return False

    def close(t1, t2):
        # Compare the already-normalized tensors directly. Calling images_equal here
        # would run image_to_normalized_tensor on them a second time. The shape check
        # prevents np.allclose from broadcasting or raising on mismatched sizes.
        return np.shape(t1) == np.shape(t2) and np.allclose(t1, t2)

    unused = list(range(len(tensors2)))  # indices into tensors2, kept in ascending order
    for tensor in tensors1:
        match = next((j for j in unused if close(tensor, tensors2[j])), None)
        if match is None:
            return False
        unused.remove(match)
    return True
|