diff --git a/image_prediction/image_extractor/extractors/parsable.py b/image_prediction/image_extractor/extractors/parsable.py index 848b60d..2d783b0 100644 --- a/image_prediction/image_extractor/extractors/parsable.py +++ b/image_prediction/image_extractor/extractors/parsable.py @@ -7,13 +7,14 @@ from typing import List import fitz from PIL import Image -from funcy import rcompose, merge, zipdict +from funcy import rcompose, merge, zipdict, pluck, curry, compose from tqdm import tqdm from image_prediction.image_extractor.extractor import ImageExtractor, ImageMetadataPair from image_prediction.info import Info from image_prediction.stitching.stitching import stitch_pairs from image_prediction.stitching.utils import validate_box_coords, validate_box_size +from image_prediction.utils.generic import lift class ParsablePDFImageExtractor(ImageExtractor): @@ -129,11 +130,14 @@ def add_page_metadata(page, metadata): def add_alpha_channel_info(doc, page, metadata): - xrefs = map(itemgetter("xref"), get_image_infos(page)) - alpha = map(partial(has_alpha_channel, doc), xrefs) - alpha = ({Info.ALPHA: a} for a in alpha) - # alpha = map(dict, zip(repeat(Info.ALPHA), alpha)) - metadata = starmap(merge, zip(alpha, metadata)) + + page_to_xrefs = compose(curry(pluck)("xref"), get_image_infos) + xref_to_alpha = partial(has_alpha_channel, doc) + page_to_alpha_value_per_image = compose(lift(xref_to_alpha), page_to_xrefs) + alpha_to_dict = compose(dict, lambda a: [(Info.ALPHA, a)]) + page_to_alpha_mapping_per_image = compose(lift(alpha_to_dict), page_to_alpha_value_per_image) + + metadata = starmap(merge, zip(page_to_alpha_mapping_per_image(page), metadata)) return metadata diff --git a/image_prediction/utils/generic.py b/image_prediction/utils/generic.py index 98cf612..9b25640 100644 --- a/image_prediction/utils/generic.py +++ b/image_prediction/utils/generic.py @@ -1,5 +1,9 @@ -from funcy import iterate, first +from funcy import iterate, first, curry, map def until(cond, func, *args, **kwargs): return first(filter(cond, iterate(func, *args, **kwargs))) + + +def lift(fn): + return curry(map)(fn) diff --git a/test/utils/comparison.py b/test/utils/comparison.py index daefd1b..f2677ce 100644 --- a/test/utils/comparison.py +++ b/test/utils/comparison.py @@ -1,8 +1,10 @@ +from itertools import starmap, product, repeat from typing import Iterable import numpy as np -from PIL import Image +from PIL.Image import Image from frozendict import frozendict +from funcy import ilen from image_prediction.estimator.preprocessor.utils import image_to_normalized_tensor @@ -15,9 +17,24 @@ def images_equal(im1: Image, im2: Image, **kwargs): return np.allclose(image_to_normalized_tensor(im1), image_to_normalized_tensor(im2), **kwargs) -def metadata_equal(mdat1: Iterable, mdat2: Iterable): +def metadata_equal(mdat1: Iterable[dict], mdat2: Iterable[dict]): return set(map(frozendict, mdat1)) == set(map(frozendict, mdat2)) -def image_sets_equal(ims1, ims2): - return all(any(images_equal(im1, im2) for im2 in ims2) for im1 in ims1) +def image_sets_equal(ims1: Iterable[Image], ims2: Iterable[Image]): + ims1, ims2 = map(lambda x: sorted(map(image_to_normalized_tensor, x), key=np.mean), (ims1, ims2)) + + n = len(ims1) + assert isinstance(ims1, list) + assert len(ims2) == n + + used = set() + covered = set() + + for im1i, im2i in product(*repeat(range(n), 2)): + + if im1i not in covered and im2i not in used and images_equal(ims1[im1i], ims2[im2i]): + covered.add(im1i) + used.add(im2i) + + return len(covered) == len(used) == n