From 50b4d239cbece2e7cf0a24de8d611225556d6568 Mon Sep 17 00:00:00 2001 From: Matthias Bisping Date: Thu, 7 Apr 2022 18:05:15 +0200 Subject: [PATCH] group merging done --- test/unit_tests/image_stitcher_test.py | 67 ++++++++++++++++++-------- 1 file changed, 48 insertions(+), 19 deletions(-) diff --git a/test/unit_tests/image_stitcher_test.py b/test/unit_tests/image_stitcher_test.py index fa85f69..10e5813 100644 --- a/test/unit_tests/image_stitcher_test.py +++ b/test/unit_tests/image_stitcher_test.py @@ -8,7 +8,7 @@ import fpdf import numpy as np import pytest from PIL import Image -from funcy import merge, second, compose, rpartial, juxt, rest, first +from funcy import merge, second, compose, rpartial, juxt, rest, first, one from image_prediction.estimator.preprocessor.utils import image_to_normalized_tensor from image_prediction.image_extractor.extractor import ImageMetadataPair @@ -39,10 +39,7 @@ def make_coord_getter(c): def make_pair_merger(direction): - return { - "y": merge_pair_vertically, - "x": merge_pair_horizontally - }[direction] + return {"y": merge_pair_vertically, "x": merge_pair_horizontally}[direction] def make_length_getter(dim): @@ -133,10 +130,8 @@ class Stitcher: groups = map(merge_group, groups) -def merge_group(group, direction="y"): - +def merge_group(group, direction): def merge_with(aggregation_pair, pairs): - def merge_on_head_and_aggregate_in_tail(pairs_aggr: Iterable[ImageMetadataPair], pair: ImageMetadataPair): """Keeps the image that is being merged with as the head and aggregates non-mergables in the tail.""" aggr, non_aggr = juxt(first, rest)(pairs_aggr) @@ -163,31 +158,67 @@ def merge_group(group, direction="y"): return new_pairs +def merge_group_horizontally(group): + return merge_group(group, "x") + + +def merge_group_vertically(group): + return merge_group(group, "y") + + ##################################### -def test_merge_group(vertical_merge_test_pairs): - pr1, pr2, pr_merged_expected = vertical_merge_test_pairs - prs_merged = merge_group([pr1, pr2]) +def test_merge_group_horizontally(horizontal_merge_test_pairs): + pr1, pr2, pr_merged_expected = horizontal_merge_test_pairs + + prs_merged = merge_group_horizontally([pr1, pr2]) assert len(prs_merged) == 1 - assert_pair_equal(prs_merged[0], pr_merged_expected) + assert pair_equal(prs_merged[0], pr_merged_expected) + + mdat3 = deepcopy(pr2.metadata) + mdat3[Info.HEIGHT] += 30 + mdat3[Info.Y2] += 30 + im3 = random_size_gray_image_from_metadata(mdat3) + pr3 = ImageMetadataPair(im3, mdat3) + + prs_merged = merge_group_horizontally([pr1, pr2, pr3]) + assert len(prs_merged) == 2 + assert one(partial(pair_equal, pr_merged_expected), prs_merged) -def assert_pair_equal(pr1, pr2): - assert pr1.metadata == pr2.metadata - assert images_equal(pr1.image, pr2.image) +def test_merge_group_vertically(vertical_merge_test_pairs): + pr1, pr2, pr_merged_expected = vertical_merge_test_pairs + + prs_merged = merge_group_vertically([pr1, pr2]) + assert len(prs_merged) == 1 + assert pair_equal(prs_merged[0], pr_merged_expected) + + mdat3 = deepcopy(pr2.metadata) + mdat3[Info.WIDTH] += 30 + mdat3[Info.X2] += 30 + im3 = random_size_gray_image_from_metadata(mdat3) + pr3 = ImageMetadataPair(im3, mdat3) + + prs_merged = merge_group_vertically([pr1, pr2, pr3]) + assert len(prs_merged) == 2 + assert one(partial(pair_equal, pr_merged_expected), prs_merged) + + +def pair_equal(pr1, pr2): + return pr1.metadata == pr2.metadata and images_equal(pr1.image, pr2.image) def test_merge_pairs_horizontally(horizontal_merge_test_pairs): pr1, pr2, pr_merged_expected = horizontal_merge_test_pairs pr_merged = merge_pair_horizontally(pr1, pr2) - assert_pair_equal(pr_merged, pr_merged_expected) + assert pair_equal(pr_merged, pr_merged_expected) def test_merge_pairs_vertically(vertical_merge_test_pairs): pr1, pr2, pr_merged_expected = vertical_merge_test_pairs pr_merged = merge_pair_vertically(pr1, pr2) - assert_pair_equal(pr_merged, pr_merged_expected) + assert pair_equal(pr_merged, pr_merged_expected) def images_equal(im1: Image, im2: Image): @@ -218,7 +249,6 @@ def test_merge_metadata_vertically(vertical_merge_test_metadata): @pytest.fixture def horizontal_merge_test_metadata(merge_test_metadata): - mdat1, mdat2, mdat_merged = merge_test_metadata mdat2[Info.X1] = mdat1[Info.X2] @@ -231,7 +261,6 @@ def horizontal_merge_test_metadata(merge_test_metadata): @pytest.fixture def vertical_merge_test_metadata(merge_test_metadata): - mdat1, mdat2, mdat_merged = merge_test_metadata mdat2[Info.Y1] = mdat1[Info.Y2]