corrected image set comparison; refactoring

This commit is contained in:
Matthias Bisping 2022-04-19 14:20:54 +02:00
parent 9d6b3e8f94
commit 7632fa8d7e
3 changed files with 36 additions and 11 deletions

View File

@ -7,13 +7,14 @@ from typing import List
import fitz
from PIL import Image
from funcy import rcompose, merge, zipdict
from funcy import rcompose, merge, zipdict, pluck, curry, compose
from tqdm import tqdm
from image_prediction.image_extractor.extractor import ImageExtractor, ImageMetadataPair
from image_prediction.info import Info
from image_prediction.stitching.stitching import stitch_pairs
from image_prediction.stitching.utils import validate_box_coords, validate_box_size
from image_prediction.utils.generic import lift
class ParsablePDFImageExtractor(ImageExtractor):
@ -129,11 +130,14 @@ def add_page_metadata(page, metadata):
def add_alpha_channel_info(doc, page, metadata):
xrefs = map(itemgetter("xref"), get_image_infos(page))
alpha = map(partial(has_alpha_channel, doc), xrefs)
alpha = ({Info.ALPHA: a} for a in alpha)
# alpha = map(dict, zip(repeat(Info.ALPHA), alpha))
metadata = starmap(merge, zip(alpha, metadata))
page_to_xrefs = compose(curry(pluck)("xref"), get_image_infos)
xref_to_alpha = partial(has_alpha_channel, doc)
page_to_alpha_value_per_image = compose(lift(xref_to_alpha), page_to_xrefs)
alpha_to_dict = compose(dict, lambda a: [(Info.ALPHA, a)])
page_to_alpha_mapping_per_image = compose(lift(alpha_to_dict), page_to_alpha_value_per_image)
metadata = starmap(merge, zip(page_to_alpha_mapping_per_image(page), metadata))
return metadata

View File

@ -1,5 +1,9 @@
from funcy import iterate, first
from funcy import iterate, first, curry, map
def until(cond, func, *args, **kwargs):
return first(filter(cond, iterate(func, *args, **kwargs)))
def lift(fn):
return curry(map)(fn)

View File

@ -1,8 +1,10 @@
from itertools import starmap, product, repeat
from typing import Iterable
import numpy as np
from PIL import Image
from PIL.Image import Image
from frozendict import frozendict
from funcy import ilen
from image_prediction.estimator.preprocessor.utils import image_to_normalized_tensor
@ -15,9 +17,24 @@ def images_equal(im1: Image, im2: Image, **kwargs):
return np.allclose(image_to_normalized_tensor(im1), image_to_normalized_tensor(im2), **kwargs)
def metadata_equal(mdat1: Iterable, mdat2: Iterable):
def metadata_equal(mdat1: Iterable[dict], mdat2: Iterable[dict]):
return set(map(frozendict, mdat1)) == set(map(frozendict, mdat2))
def image_sets_equal(ims1, ims2):
return all(any(images_equal(im1, im2) for im2 in ims2) for im1 in ims1)
def image_sets_equal(ims1: Iterable[Image], ims2: Iterable[Image]):
ims1, ims2 = map(lambda x: sorted(map(image_to_normalized_tensor, x), key=np.mean), (ims1, ims2))
n = len(ims1)
assert isinstance(ims1, list)
assert len(ims2) == n
used = set()
covered = set()
for im1i, im2i in product(*repeat(range(n), 2)):
if im1i not in covered and im2i not in used and images_equal(ims1[im1i], ims2[im2i]):
covered.add(im1i)
used.add(im2i)
return len(covered) == len(used) == n