corrected image set comparison; refactoring

This commit is contained in:
Matthias Bisping 2022-04-19 14:20:54 +02:00
parent 9d6b3e8f94
commit 7632fa8d7e
3 changed files with 36 additions and 11 deletions

View File

@ -7,13 +7,14 @@ from typing import List
import fitz import fitz
from PIL import Image from PIL import Image
from funcy import rcompose, merge, zipdict from funcy import rcompose, merge, zipdict, pluck, curry, compose
from tqdm import tqdm from tqdm import tqdm
from image_prediction.image_extractor.extractor import ImageExtractor, ImageMetadataPair from image_prediction.image_extractor.extractor import ImageExtractor, ImageMetadataPair
from image_prediction.info import Info from image_prediction.info import Info
from image_prediction.stitching.stitching import stitch_pairs from image_prediction.stitching.stitching import stitch_pairs
from image_prediction.stitching.utils import validate_box_coords, validate_box_size from image_prediction.stitching.utils import validate_box_coords, validate_box_size
from image_prediction.utils.generic import lift
class ParsablePDFImageExtractor(ImageExtractor): class ParsablePDFImageExtractor(ImageExtractor):
@ -129,11 +130,14 @@ def add_page_metadata(page, metadata):
def add_alpha_channel_info(doc, page, metadata): def add_alpha_channel_info(doc, page, metadata):
xrefs = map(itemgetter("xref"), get_image_infos(page))
alpha = map(partial(has_alpha_channel, doc), xrefs) page_to_xrefs = compose(curry(pluck)("xref"), get_image_infos)
alpha = ({Info.ALPHA: a} for a in alpha) xref_to_alpha = partial(has_alpha_channel, doc)
# alpha = map(dict, zip(repeat(Info.ALPHA), alpha)) page_to_alpha_value_per_image = compose(lift(xref_to_alpha), page_to_xrefs)
metadata = starmap(merge, zip(alpha, metadata)) alpha_to_dict = compose(dict, lambda a: [(Info.ALPHA, a)])
page_to_alpha_mapping_per_image = compose(lift(alpha_to_dict), page_to_alpha_value_per_image)
metadata = starmap(merge, zip(page_to_alpha_mapping_per_image(page), metadata))
return metadata return metadata

View File

@ -1,5 +1,9 @@
from funcy import iterate, first from funcy import iterate, first, curry, map
def until(cond, func, *args, **kwargs): def until(cond, func, *args, **kwargs):
return first(filter(cond, iterate(func, *args, **kwargs))) return first(filter(cond, iterate(func, *args, **kwargs)))
def lift(fn):
return curry(map)(fn)

View File

@ -1,8 +1,10 @@
from itertools import starmap, product, repeat
from typing import Iterable from typing import Iterable
import numpy as np import numpy as np
from PIL import Image from PIL.Image import Image
from frozendict import frozendict from frozendict import frozendict
from funcy import ilen
from image_prediction.estimator.preprocessor.utils import image_to_normalized_tensor from image_prediction.estimator.preprocessor.utils import image_to_normalized_tensor
@ -15,9 +17,24 @@ def images_equal(im1: Image, im2: Image, **kwargs):
return np.allclose(image_to_normalized_tensor(im1), image_to_normalized_tensor(im2), **kwargs) return np.allclose(image_to_normalized_tensor(im1), image_to_normalized_tensor(im2), **kwargs)
def metadata_equal(mdat1: Iterable, mdat2: Iterable): def metadata_equal(mdat1: Iterable[dict], mdat2: Iterable[dict]):
return set(map(frozendict, mdat1)) == set(map(frozendict, mdat2)) return set(map(frozendict, mdat1)) == set(map(frozendict, mdat2))
def image_sets_equal(ims1, ims2): def image_sets_equal(ims1: Iterable[Image], ims2: Iterable[Image]):
return all(any(images_equal(im1, im2) for im2 in ims2) for im1 in ims1) ims1, ims2 = map(lambda x: sorted(map(image_to_normalized_tensor, x), key=np.mean), (ims1, ims2))
n = len(ims1)
assert isinstance(ims1, list)
assert len(ims2) == n
used = set()
covered = set()
for im1i, im2i in product(*repeat(range(n), 2)):
if im1i not in covered and im2i not in used and images_equal(ims1[im1i], ims2[im2i]):
covered.add(im1i)
used.add(im2i)
return len(covered) == len(used) == n