From 8ac9fcb19f336a616103c7d819125ec5a9e0d7a1 Mon Sep 17 00:00:00 2001 From: Matthias Bisping Date: Thu, 7 Apr 2022 19:40:26 +0200 Subject: [PATCH] stitcher test passes --- test/unit_tests/image_stitcher_test.py | 87 ++++++++++++++++---------- 1 file changed, 54 insertions(+), 33 deletions(-) diff --git a/test/unit_tests/image_stitcher_test.py b/test/unit_tests/image_stitcher_test.py index e6c546e..267cd40 100644 --- a/test/unit_tests/image_stitcher_test.py +++ b/test/unit_tests/image_stitcher_test.py @@ -6,6 +6,7 @@ from typing import Iterable, List import fpdf import numpy as np +import pdf2image import pytest from PIL import Image from funcy import merge, second, compose, rpartial, juxt, rest, first, one, iterate @@ -117,24 +118,11 @@ def concat_images(im1: Image, im2: Image, metadata: dict, axis): return im_aggr -class Stitcher: - @staticmethod - def groupby(pairs, coord): - coord_getter = make_coord_getter(coord) - pairs = sorted(pairs, key=coord_getter) - return map(compose(list, second), groupby(pairs, coord_getter)) - - def stitch(self, pairs: Iterable[ImageMetadataPair]) -> ImageMetadataPair: - groups = self.groupby(pairs, "x1") - groups = chain.from_iterable(map(rpartial(self.groupby, "x2"), groups)) - groups = map(partial(sorted, key=y1_getter), groups) - groups = map(merge_group, groups) - - def make_merger_aggregator(direction): """Produces a function f : [H, T1, ... Tn] -> [HTi...Tj, Tk ... Tl] that merges adjacent image-metadata pairs on the head H and aggregates non-adjacent in the tail T. """ + def merger_aggregator(pairs): def merge_on_head_and_aggregate_in_tail(pairs_aggr: Iterable[ImageMetadataPair], pair: ImageMetadataPair): """Keeps the image that is being merged with as the head and aggregates non-mergables in the tail.""" @@ -175,9 +163,50 @@ def merge_group_vertically(group): return merge_group(group, "y") +class Stitcher: + @staticmethod + def groupby(pairs, coord): + coord_getter = make_coord_getter(coord) + pairs = sorted(pairs, key=coord_getter) + return map(compose(list, second), groupby(pairs, coord_getter)) + + def stitch(self, pairs: Iterable[ImageMetadataPair]) -> List[ImageMetadataPair]: + + pairs = list(pairs) + n = len(pairs) + + while True: + + groups = self.groupby(pairs, "x1") + groups = chain.from_iterable(map(rpartial(self.groupby, "x2"), groups)) + groups = map(merge_group_vertically, groups) + pairs = chain.from_iterable(groups) + + groups = self.groupby(pairs, "y1") + groups = chain.from_iterable(map(rpartial(self.groupby, "y2"), groups)) + groups = map(merge_group_horizontally, groups) + pairs = list(chain.from_iterable(groups)) + + if len(pairs) == n: + return pairs + n = len(pairs) + + ##################################### +# @pytest.mark.parametrize("width", [160]) +# @pytest.mark.parametrize("height", [90]) +# @pytest.mark.parametrize("page_width", [int(160 * 1.1)]) +# @pytest.mark.parametrize("page_height", [int(90 * 1.1)]) +def test_image_stitcher(patch_image_metadata_pairs, base_patch_metadata, base_patch_image): + pair_stitched = Stitcher().stitch(patch_image_metadata_pairs)[0] + assert pair_stitched.metadata == base_patch_metadata + # pair_stitched.image.show() + # base_patch_image.show() + assert images_equal(pair_stitched.image.resize((10, 10)), base_patch_image.resize((10, 10)), atol=0.4) + + def test_merge_group_horizontally(horizontal_merge_test_pairs): pr1, pr2, pr_merged_expected = horizontal_merge_test_pairs @@ -230,8 +259,8 @@ def test_merge_pairs_vertically(vertical_merge_test_pairs): assert pair_equal(pr_merged, pr_merged_expected) -def images_equal(im1: Image, im2: Image): - return np.allclose(image_to_normalized_tensor(im1), image_to_normalized_tensor(im2)) +def images_equal(im1: Image, im2: Image, **kwargs): + return np.allclose(image_to_normalized_tensor(im1), image_to_normalized_tensor(im2), **kwargs) @pytest.fixture @@ -285,6 +314,11 @@ def merge_test_metadata(base_patch_metadata): return juxt(*repeat(deepcopy, 3))(base_patch_metadata) +@pytest.fixture +def base_patch_image(stitch_test_pdf): + return pdf2image.convert_from_bytes(stitch_test_pdf)[0] + + def test_concat_images_horizontally(horizontal_merge_test_metadata): mdat1, mdat2, mdat_merged = horizontal_merge_test_metadata im1, im2, im_merged_expected = map(random_size_gray_image_from_metadata, [mdat1, mdat2, mdat_merged]) @@ -301,28 +335,15 @@ def test_concat_images_vertically(vertical_merge_test_metadata): assert images_equal(im_merged, im_merged_expected) -@pytest.mark.parametrize("width", [160]) -@pytest.mark.parametrize("height", [90]) -@pytest.mark.parametrize("page_width", [int(160 * 1.1)]) -@pytest.mark.parametrize("page_height", [int(90 * 1.1)]) -@pytest.mark.skip() -def test_image_stitcher(patches_metadata, base_patch_metadata): - # noinspection PyTypeChecker - assert Stitcher().stitch(patch_image_metadata_pairs).metadata == base_patch_metadata +@pytest.fixture +def stitch_test_pdf(patch_image_metadata_pairs, width, height): - -@pytest.mark.parametrize("width", [160]) -@pytest.mark.parametrize("height", [90]) -@pytest.mark.parametrize("page_width", [int(160 * 1.1)]) -@pytest.mark.parametrize("page_height", [int(90 * 1.1)]) -def test_partial_image_metadata_pairs(patch_image_metadata_pairs, page_width, page_height): - - pdf = fpdf.FPDF(unit="pt", format=(page_width, page_height)) + pdf = fpdf.FPDF(unit="pt", format=(width, height)) for pair in patch_image_metadata_pairs: add_image(pdf, pair) - pdf.output("/tmp/bla.pdf") + return pdf.output(dest="S").encode("latin1") @pytest.fixture