stitcher test passes

This commit is contained in:
Matthias Bisping 2022-04-07 19:40:26 +02:00
parent 160973e2be
commit 8ac9fcb19f

View File

@ -6,6 +6,7 @@ from typing import Iterable, List
import fpdf
import numpy as np
import pdf2image
import pytest
from PIL import Image
from funcy import merge, second, compose, rpartial, juxt, rest, first, one, iterate
@ -117,24 +118,11 @@ def concat_images(im1: Image, im2: Image, metadata: dict, axis):
return im_aggr
class Stitcher:
@staticmethod
def groupby(pairs, coord):
coord_getter = make_coord_getter(coord)
pairs = sorted(pairs, key=coord_getter)
return map(compose(list, second), groupby(pairs, coord_getter))
def stitch(self, pairs: Iterable[ImageMetadataPair]) -> ImageMetadataPair:
groups = self.groupby(pairs, "x1")
groups = chain.from_iterable(map(rpartial(self.groupby, "x2"), groups))
groups = map(partial(sorted, key=y1_getter), groups)
groups = map(merge_group, groups)
def make_merger_aggregator(direction):
"""Produces a function f : [H, T1, ... Tn] -> [HTi...Tj, Tk ... Tl] that merges adjacent image-metadata pairs on the
head H and aggregates non-adjacent in the tail T.
"""
def merger_aggregator(pairs):
def merge_on_head_and_aggregate_in_tail(pairs_aggr: Iterable[ImageMetadataPair], pair: ImageMetadataPair):
"""Keeps the image that is being merged with as the head and aggregates non-mergables in the tail."""
@ -175,9 +163,50 @@ def merge_group_vertically(group):
return merge_group(group, "y")
class Stitcher:
@staticmethod
def groupby(pairs, coord):
coord_getter = make_coord_getter(coord)
pairs = sorted(pairs, key=coord_getter)
return map(compose(list, second), groupby(pairs, coord_getter))
def stitch(self, pairs: Iterable[ImageMetadataPair]) -> List[ImageMetadataPair]:
pairs = list(pairs)
n = len(pairs)
while True:
groups = self.groupby(pairs, "x1")
groups = chain.from_iterable(map(rpartial(self.groupby, "x2"), groups))
groups = map(merge_group_vertically, groups)
pairs = chain.from_iterable(groups)
groups = self.groupby(pairs, "y1")
groups = chain.from_iterable(map(rpartial(self.groupby, "y2"), groups))
groups = map(merge_group_horizontally, groups)
pairs = list(chain.from_iterable(groups))
if len(pairs) == n:
return pairs
n = len(pairs)
#####################################
# @pytest.mark.parametrize("width", [160])
# @pytest.mark.parametrize("height", [90])
# @pytest.mark.parametrize("page_width", [int(160 * 1.1)])
# @pytest.mark.parametrize("page_height", [int(90 * 1.1)])
def test_image_stitcher(patch_image_metadata_pairs, base_patch_metadata, base_patch_image):
pair_stitched = Stitcher().stitch(patch_image_metadata_pairs)[0]
assert pair_stitched.metadata == base_patch_metadata
# pair_stitched.image.show()
# base_patch_image.show()
assert images_equal(pair_stitched.image.resize((10, 10)), base_patch_image.resize((10, 10)), atol=0.4)
def test_merge_group_horizontally(horizontal_merge_test_pairs):
pr1, pr2, pr_merged_expected = horizontal_merge_test_pairs
@ -230,8 +259,8 @@ def test_merge_pairs_vertically(vertical_merge_test_pairs):
assert pair_equal(pr_merged, pr_merged_expected)
def images_equal(im1: Image, im2: Image):
return np.allclose(image_to_normalized_tensor(im1), image_to_normalized_tensor(im2))
def images_equal(im1: Image, im2: Image, **kwargs):
return np.allclose(image_to_normalized_tensor(im1), image_to_normalized_tensor(im2), **kwargs)
@pytest.fixture
@ -285,6 +314,11 @@ def merge_test_metadata(base_patch_metadata):
return juxt(*repeat(deepcopy, 3))(base_patch_metadata)
@pytest.fixture
def base_patch_image(stitch_test_pdf):
return pdf2image.convert_from_bytes(stitch_test_pdf)[0]
def test_concat_images_horizontally(horizontal_merge_test_metadata):
mdat1, mdat2, mdat_merged = horizontal_merge_test_metadata
im1, im2, im_merged_expected = map(random_size_gray_image_from_metadata, [mdat1, mdat2, mdat_merged])
@ -301,28 +335,15 @@ def test_concat_images_vertically(vertical_merge_test_metadata):
assert images_equal(im_merged, im_merged_expected)
@pytest.mark.parametrize("width", [160])
@pytest.mark.parametrize("height", [90])
@pytest.mark.parametrize("page_width", [int(160 * 1.1)])
@pytest.mark.parametrize("page_height", [int(90 * 1.1)])
@pytest.mark.skip()
def test_image_stitcher(patches_metadata, base_patch_metadata):
# noinspection PyTypeChecker
assert Stitcher().stitch(patch_image_metadata_pairs).metadata == base_patch_metadata
@pytest.fixture
def stitch_test_pdf(patch_image_metadata_pairs, width, height):
@pytest.mark.parametrize("width", [160])
@pytest.mark.parametrize("height", [90])
@pytest.mark.parametrize("page_width", [int(160 * 1.1)])
@pytest.mark.parametrize("page_height", [int(90 * 1.1)])
def test_partial_image_metadata_pairs(patch_image_metadata_pairs, page_width, page_height):
pdf = fpdf.FPDF(unit="pt", format=(page_width, page_height))
pdf = fpdf.FPDF(unit="pt", format=(width, height))
for pair in patch_image_metadata_pairs:
add_image(pdf, pair)
pdf.output("/tmp/bla.pdf")
return pdf.output(dest="S").encode("latin1")
@pytest.fixture