corrected image set comparison; refactoring
This commit is contained in:
parent
9d6b3e8f94
commit
7632fa8d7e
@ -7,13 +7,14 @@ from typing import List
|
|||||||
|
|
||||||
import fitz
|
import fitz
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
from funcy import rcompose, merge, zipdict
|
from funcy import rcompose, merge, zipdict, pluck, curry, compose
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
from image_prediction.image_extractor.extractor import ImageExtractor, ImageMetadataPair
|
from image_prediction.image_extractor.extractor import ImageExtractor, ImageMetadataPair
|
||||||
from image_prediction.info import Info
|
from image_prediction.info import Info
|
||||||
from image_prediction.stitching.stitching import stitch_pairs
|
from image_prediction.stitching.stitching import stitch_pairs
|
||||||
from image_prediction.stitching.utils import validate_box_coords, validate_box_size
|
from image_prediction.stitching.utils import validate_box_coords, validate_box_size
|
||||||
|
from image_prediction.utils.generic import lift
|
||||||
|
|
||||||
|
|
||||||
class ParsablePDFImageExtractor(ImageExtractor):
|
class ParsablePDFImageExtractor(ImageExtractor):
|
||||||
@ -129,11 +130,14 @@ def add_page_metadata(page, metadata):
|
|||||||
|
|
||||||
|
|
||||||
def add_alpha_channel_info(doc, page, metadata):
    """Attach an ``Info.ALPHA`` entry to each metadata item of a page.

    For every image info returned by ``get_image_infos(page)`` the ``"xref"``
    value is passed, together with ``doc``, to ``has_alpha_channel``. The
    result is stored under ``Info.ALPHA`` and merged with the metadata item at
    the same position.

    Returns a lazy iterator of merged metadata; pairing stops at the shorter
    of the two sequences because the items are combined with ``zip``.
    """
    # One alpha flag per image on the page, computed lazily.
    alpha_flags = (has_alpha_channel(doc, info["xref"]) for info in get_image_infos(page))
    alpha_entries = ({Info.ALPHA: flag} for flag in alpha_flags)

    # The metadata item is the later argument to merge, exactly as before.
    return starmap(merge, zip(alpha_entries, metadata))
|
||||||
|
|
||||||
|
|||||||
@ -1,5 +1,9 @@
|
|||||||
from funcy import iterate, first
|
from funcy import iterate, first, curry, map
|
||||||
|
|
||||||
|
|
||||||
def until(cond, func, *args, **kwargs):
    """Return the first value produced by repeatedly applying ``func`` that ``cond`` accepts.

    The candidate values come from ``iterate(func, *args, **kwargs)``; they are
    filtered with ``cond`` and the first surviving value is returned.
    """
    candidates = iterate(func, *args, **kwargs)
    accepted = filter(cond, candidates)
    return first(accepted)
|
||||||
|
|
||||||
|
|
||||||
|
def lift(fn):
    """Turn ``fn`` into a function that maps ``fn`` over a sequence.

    Built by currying ``map`` and supplying ``fn`` as its first argument, so
    the returned callable only awaits the sequence to map over.
    """
    curried_map = curry(map)
    return curried_map(fn)
|
||||||
|
|||||||
@ -1,8 +1,10 @@
|
|||||||
|
from itertools import starmap, product, repeat
|
||||||
from typing import Iterable
|
from typing import Iterable
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from PIL import Image
|
from PIL.Image import Image
|
||||||
from frozendict import frozendict
|
from frozendict import frozendict
|
||||||
|
from funcy import ilen
|
||||||
|
|
||||||
from image_prediction.estimator.preprocessor.utils import image_to_normalized_tensor
|
from image_prediction.estimator.preprocessor.utils import image_to_normalized_tensor
|
||||||
|
|
||||||
@ -15,9 +17,24 @@ def images_equal(im1: Image, im2: Image, **kwargs):
|
|||||||
return np.allclose(image_to_normalized_tensor(im1), image_to_normalized_tensor(im2), **kwargs)
|
return np.allclose(image_to_normalized_tensor(im1), image_to_normalized_tensor(im2), **kwargs)
|
||||||
|
|
||||||
|
|
||||||
def metadata_equal(mdat1: Iterable[dict], mdat2: Iterable[dict]):
    """Compare two collections of metadata dicts as sets.

    Every dict is converted to a ``frozendict`` so it can be placed in a set;
    order and duplicate entries therefore do not influence the result.
    """
    frozen1 = {frozendict(entry) for entry in mdat1}
    frozen2 = {frozendict(entry) for entry in mdat2}
    return frozen1 == frozen2
|
||||||
|
|
||||||
|
|
||||||
def image_sets_equal(ims1: Iterable[Image], ims2: Iterable[Image]):
    """Check whether two image collections match one-to-one.

    Both collections are converted with ``image_to_normalized_tensor`` and
    ordered by their mean value. Each entry of ``ims1`` is then paired with the
    first not-yet-paired entry of ``ims2`` that ``images_equal`` accepts, so an
    image of ``ims2`` can serve as the partner of at most one image of ``ims1``.

    Returns:
        ``True`` if both collections hold the same number of images and every
        image of ``ims1`` found its own partner in ``ims2``; ``False``
        otherwise, including when the collections differ in size.
    """
    tensors1, tensors2 = [
        sorted(map(image_to_normalized_tensor, ims), key=np.mean) for ims in (ims1, ims2)
    ]

    # Collections of different size can never match one-to-one. This is an
    # explicit check rather than an assert: asserts are stripped under
    # ``python -O``, which would let a longer ``ims2`` compare equal or a
    # shorter one raise IndexError.
    if len(tensors1) != len(tensors2):
        return False

    used = set()  # indices into tensors2 that already have a partner

    for tensor1 in tensors1:
        # NOTE(review): images_equal receives already-converted tensors here
        # (unchanged from the previous implementation) -- confirm that
        # image_to_normalized_tensor is safe to apply to its own output.
        partner = next(
            (
                index
                for index, tensor2 in enumerate(tensors2)
                if index not in used and images_equal(tensor1, tensor2)
            ),
            None,
        )
        # Stop as soon as one image cannot be paired; the final answer is
        # already known and further comparisons would be wasted.
        if partner is None:
            return False
        used.add(partner)

    return True
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user