from collections.abc import Iterator
from functools import partial
from itertools import starmap, product, repeat
from typing import Iterable

import numpy as np
from PIL.Image import Image
from frozendict import frozendict
from funcy import ilen, compose, omit

from image_prediction.estimator.preprocessor.utils import image_to_normalized_tensor
from image_prediction.info import Info
from image_prediction.utils.generic import lift


def transform_equal(a, b):
    """Return whether ``a == b``, materializing ``a`` first if it is a one-shot iterator.

    Lazy results such as ``map``/``filter`` objects or generators never compare equal
    to a list, so they are turned into a list before comparing.
    """
    return (list(a) if isinstance(a, Iterator) else a) == b


def _tensors_equal(t1, t2, **kwargs) -> bool:
    """Return True if two arrays have identical shapes and are element-wise close.

    The shape check stops ``np.allclose`` from broadcasting (e.g. a (1, 3) array
    "matching" a (3, 3) array) or raising on incompatible shapes.
    ``kwargs`` (``rtol``, ``atol``, ``equal_nan``) are forwarded to ``np.allclose``.
    """
    return np.shape(t1) == np.shape(t2) and bool(np.allclose(t1, t2, **kwargs))


def images_equal(im1: Image, im2: Image, **kwargs) -> bool:
    """Return True if two images are equal after conversion with ``image_to_normalized_tensor``.

    ``kwargs`` are forwarded to ``np.allclose`` to control tolerance.
    """
    # NOTE(review): assumes image_to_normalized_tensor returns an array-like; confirm.
    return _tensors_equal(image_to_normalized_tensor(im1), image_to_normalized_tensor(im2), **kwargs)


def metadata_equal(mdat1: Iterable[dict], mdat2: Iterable[dict]) -> bool:
    """Return True if two collections of metadata dicts match, ignoring order and ``Info.XREF``.

    Each dict has its ``Info.XREF`` key dropped, is frozen into a ``frozendict`` (so its
    values must be hashable) and is collected into a set. Duplicates therefore collapse:
    ``[m, m]`` compares equal to ``[m]``.
    """
    # lift presumably maps the wrapped function over each element — confirm in utils.generic.
    f = compose(set, lift(compose(frozendict, partial(omit, keys=[Info.XREF]))))
    return f(mdat1) == f(mdat2)


def image_sets_equal(ims1: Iterable[Image], ims2: Iterable[Image], **kwargs) -> bool:
    """Return True if the two image collections match one-to-one, ignoring order.

    Both collections are converted with ``image_to_normalized_tensor``. Each image of
    ``ims1`` is then greedily paired with the first still-unpaired, equal image of
    ``ims2``. Collections of different sizes are never equal. ``kwargs`` are forwarded
    to ``np.allclose`` to control tolerance.
    """
    tensors1 = sorted(map(image_to_normalized_tensor, ims1), key=np.mean)
    tensors2 = sorted(map(image_to_normalized_tensor, ims2), key=np.mean)

    # Explicit check instead of `assert`, so the result does not change under `python -O`.
    if len(tensors1) != len(tensors2):
        return False

    unmatched = list(range(len(tensors2)))  # indices into tensors2 not yet paired
    for t1 in tensors1:
        for pos, j in enumerate(unmatched):
            # Compare the tensors directly: they are already converted and must not be
            # run through image_to_normalized_tensor a second time (as images_equal would).
            if _tensors_equal(t1, tensors2[j], **kwargs):
                del unmatched[pos]
                break
        else:
            # No partner for t1, so not every image of ims1 can be covered.
            return False
    return True