from itertools import starmap, product, repeat
from typing import Iterable

import numpy as np
from PIL.Image import Image
from frozendict import frozendict
from funcy import ilen

from image_prediction.estimator.preprocessor.utils import image_to_normalized_tensor


def transform_equal(a, b):
    """Compare ``a`` with ``b``, materializing ``a`` first if it is a ``map`` object.

    A lazy ``map`` never compares equal to a list, so it is turned into a
    list before the comparison. Any other value of ``a`` is compared as-is.
    """
    return (list(a) if isinstance(a, map) else a) == b


def images_equal(im1: Image, im2: Image, **kwargs):
    """Return whether two images are numerically close.

    Both arguments are passed through ``image_to_normalized_tensor`` and the
    results are compared with ``np.allclose``.

    Args:
        im1: First image.
        im2: Second image.
        **kwargs: Forwarded unchanged to ``np.allclose`` (e.g. ``rtol``, ``atol``).
    """
    return np.allclose(image_to_normalized_tensor(im1), image_to_normalized_tensor(im2), **kwargs)


def metadata_equal(mdat1: Iterable[dict], mdat2: Iterable[dict]):
    """Return whether two collections of metadata dicts hold the same entries.

    Each dict is wrapped in a ``frozendict`` and the two collections are
    compared as sets, so ordering and duplicate entries are ignored.
    """
    return set(map(frozendict, mdat1)) == set(map(frozendict, mdat2))


def image_sets_equal(ims1: Iterable[Image], ims2: Iterable[Image]):
    """Return whether two image collections match one-to-one, ignoring order.

    Every image is converted with ``image_to_normalized_tensor`` and both
    collections are sorted by mean value. Each entry of the first collection
    is then greedily paired with the first not-yet-paired entry of the second
    collection for which ``images_equal`` holds.

    Args:
        ims1: First collection of images.
        ims2: Second collection of images.

    Returns:
        ``True`` if both collections have the same size and every entry of
        ``ims1`` could be paired with a distinct entry of ``ims2``; ``False``
        otherwise. Two empty collections are considered equal.
    """
    tensors1, tensors2 = (
        sorted(map(image_to_normalized_tensor, ims), key=np.mean) for ims in (ims1, ims2)
    )

    # Collections of different size can never match one-to-one. This is an
    # explicit check rather than an ``assert`` so the predicate answers False
    # instead of raising, and so the check is not stripped under ``python -O``.
    if len(tensors1) != len(tensors2):
        return False

    used = set()  # indices into tensors2 that are already paired
    for tensor1 in tensors1:
        # NOTE(review): images_equal applies image_to_normalized_tensor again to
        # these already-converted values; kept as in the original -- confirm the
        # helper is meant to accept its own output.
        match = next(
            (
                idx
                for idx, tensor2 in enumerate(tensors2)
                if idx not in used and images_equal(tensor1, tensor2)
            ),
            None,
        )
        if match is None:
            # One unmatched entry is enough to decide the outcome.
            return False
        used.add(match)

    return True