2023-02-01 14:38:55 +01:00

45 lines
1.3 KiB
Python

from functools import partial
from itertools import starmap, product, repeat
from typing import Iterable
import numpy as np
from PIL.Image import Image
from frozendict import frozendict
from funcy import ilen, compose, omit
from image_prediction.estimator.preprocessor.utils import image_to_normalized_tensor
from image_prediction.info import Info
from image_prediction.utils.generic import lift
def transform_equal(a, b):
    """Compare a transformation result against an expected value.

    If ``a`` is a ``map`` object it is first materialized into a list, so a
    lazily produced result can be checked against a list literal. Any other
    value (including other iterators such as generators) is compared as-is.

    Returns:
        The result of ``==`` between the (possibly materialized) ``a`` and ``b``.
    """
    if isinstance(a, map):
        a = list(a)
    return a == b
def images_equal(im1: Image, im2: Image, **kwargs):
    """Return whether two images are numerically close after normalization.

    Both images are converted with ``image_to_normalized_tensor`` and compared
    element-wise with ``np.allclose``.

    Args:
        im1: First image.
        im2: Second image.
        **kwargs: Forwarded to ``np.allclose`` (e.g. ``rtol``, ``atol``,
            ``equal_nan``).

    Returns:
        True if both normalized tensors have the same shape and all their
        elements are close; False otherwise, including on a shape mismatch.
    """
    tensor1 = image_to_normalized_tensor(im1)
    tensor2 = image_to_normalized_tensor(im2)
    # np.allclose broadcasts its operands: a tensor with size-1 axes would
    # "match" a larger, uniformly valued one, and non-broadcastable shapes
    # raise ValueError. Differently sized images are simply not equal.
    if np.shape(tensor1) != np.shape(tensor2):
        return False
    return bool(np.allclose(tensor1, tensor2, **kwargs))
def metadata_equal(mdat1: Iterable[dict], mdat2: Iterable[dict]):
    """Return whether two collections of metadata dicts match, ignoring order and ``Info.XREF``.

    Each dict has its ``Info.XREF`` entry removed (via ``funcy.omit``) and is
    frozen into a hashable ``frozendict``. The two collections are then
    compared as multisets: order is irrelevant, but how often each dict occurs
    matters. Comparing plain sets would wrongly treat ``[a, a, b]`` and
    ``[a, b, b]`` as equal.

    Args:
        mdat1: First collection of metadata dicts.
        mdat2: Second collection of metadata dicts.

    Returns:
        True if both collections contain the same XREF-less dicts with the
        same multiplicities.
    """
    from collections import Counter

    # lift(f) is used here, as in the original, to apply f to every element of
    # the collection (assumed map-like semantics of the project helper).
    normalize = compose(Counter, lift(compose(frozendict, partial(omit, keys=[Info.XREF]))))
    return normalize(mdat1) == normalize(mdat2)
def _has_perfect_matching(candidates):
    """Return whether every left node can be paired with a distinct right node.

    ``candidates[i]`` lists the right-hand indices that left node ``i`` may be
    paired with; there must be as many right nodes as left nodes. Uses Kuhn's
    augmenting-path algorithm. Unlike greedy first-fit pairing, it never
    misses a valid assignment because an early choice blocked a later one.
    Recursion depth is bounded by the number of nodes.
    """
    owner = [None] * len(candidates)  # owner[j]: left index currently paired with right j

    def assign(i, visited):
        for j in candidates[i]:
            if j in visited:
                continue
            visited.add(j)
            # Take j if it is free, or if its current owner can be re-paired elsewhere.
            if owner[j] is None or assign(owner[j], visited):
                owner[j] = i
                return True
        return False

    # A left node that cannot be placed means no perfect matching exists.
    return all(assign(i, set()) for i in range(len(candidates)))


def image_sets_equal(ims1: Iterable[Image], ims2: Iterable[Image]):
    """Return whether two collections of images are equal up to ordering.

    Every image is normalized once with ``image_to_normalized_tensor``. The
    collections are equal when each image in ``ims1`` can be paired with a
    distinct image in ``ims2`` whose tensor has the same shape and is
    ``np.allclose`` to it. Because closeness is not transitive, the pairing
    is found with a bipartite matching search rather than greedily.

    Args:
        ims1: First collection of images.
        ims2: Second collection of images.

    Returns:
        True if a one-to-one pairing of close images exists; False otherwise,
        including when the collections differ in size.
    """
    tensors1 = [image_to_normalized_tensor(im) for im in ims1]
    tensors2 = [image_to_normalized_tensor(im) for im in ims2]
    if len(tensors1) != len(tensors2):
        return False

    def close(t1, t2):
        # Tensors are already normalized, so compare them directly; the shape
        # check stops np.allclose from broadcasting mismatched sizes.
        return np.shape(t1) == np.shape(t2) and bool(np.allclose(t1, t2))

    candidates = [[j for j, t2 in enumerate(tensors2) if close(t1, t2)] for t1 in tensors1]
    return _has_perfect_matching(candidates)