diff --git a/test/unit_tests/image_stitcher_test.py b/test/unit_tests/image_stitcher_test.py index 86be406..f363d8d 100644 --- a/test/unit_tests/image_stitcher_test.py +++ b/test/unit_tests/image_stitcher_test.py @@ -49,15 +49,6 @@ x1_getter, y1_getter, x2_getter, y2_getter = map(make_coord_getter, ("x1", "y1", width_getter, height_getter = map(make_length_getter, ("width", "height")) -def merge_group(group, axis="y"): - - group = list(group) - current_pair = group.pop(0) - for pair in group: - if y2_getter(current_pair) == y1_getter(pair): - current_box = merge_pair_vertically(current_pair, pair) - - def merge_metadata_horizontally(m1, m2): m1, m2 = map(HorizontalKeyMapper, [m1, m2]) return merge_metadata(m1, m2) @@ -121,25 +112,73 @@ def concat_images(im1: Image, im2: Image, metadata: dict, axis): return im_aggr +class Stitcher: + @staticmethod + def groupby(pairs, coord): + coord_getter = make_coord_getter(coord) + pairs = sorted(pairs, key=coord_getter) + return map(compose(list, second), groupby(pairs, coord_getter)) + + def stitch(self, pairs: Iterable[ImageMetadataPair]) -> ImageMetadataPair: + groups = self.groupby(pairs, "x1") + groups = chain.from_iterable(map(rpartial(self.groupby, "x2"), groups)) + groups = map(partial(sorted, key=y1_getter), groups) + groups = map(merge_group, groups) + + +def merge_group(group): + def f(pairs): + current_pair = pairs.pop(0) + + to_remove = [] + + for pair in pairs: + if y2_getter(current_pair) == y1_getter(pair): + current_pair = merge_pair_vertically(current_pair, pair) + to_remove.append(pair) + + return [current_pair, *filter(lambda p: p not in to_remove, pairs)] + + pairs = list(group) + + while True: + new_pairs = f(deepcopy(pairs)) + if len(new_pairs) == len(pairs): + break + pairs = new_pairs + + return new_pairs + + ##################################### +def test_merge_group(vertical_merge_test_pairs): + pr1, pr2, pr_merged_expected = vertical_merge_test_pairs + prs_merged = merge_group([pr1, pr2]) + assert len(prs_merged) == 1 + assert_pair_equal(prs_merged[0], pr_merged_expected) + + +def assert_pair_equal(pr1, pr2): + assert pr1.metadata == pr2.metadata + assert images_equal(pr1.image, pr2.image) + + def test_merge_pairs_horizontally(horizontal_merge_test_pairs): pr1, pr2, pr_merged_expected = horizontal_merge_test_pairs pr_merged = merge_pair_horizontally(pr1, pr2) - assert pr_merged.metadata == pr_merged_expected.metadata - images_equal(pr_merged.image, pr_merged_expected.image) + assert_pair_equal(pr_merged, pr_merged_expected) def test_merge_pairs_vertically(vertical_merge_test_pairs): pr1, pr2, pr_merged_expected = vertical_merge_test_pairs pr_merged = merge_pair_vertically(pr1, pr2) - assert pr_merged.metadata == pr_merged_expected.metadata - images_equal(pr_merged.image, pr_merged_expected.image) + assert_pair_equal(pr_merged, pr_merged_expected) def images_equal(im1: Image, im2: Image): - assert np.allclose(image_to_normalized_tensor(im1), image_to_normalized_tensor(im2)) + return np.allclose(image_to_normalized_tensor(im1), image_to_normalized_tensor(im2)) @pytest.fixture @@ -200,28 +239,15 @@ def test_concat_images_horizontally(horizontal_merge_test_metadata): im1, im2, im_merged_expected = map(random_size_gray_image_from_metadata, [mdat1, mdat2, mdat_merged]) im_merged = concat_images_horizontally(im1, im2, mdat_merged) assert im_merged.size == im_merged_expected.size - images_equal(im_merged, im_merged_expected) + assert images_equal(im_merged, im_merged_expected) def test_concat_images_vertically(vertical_merge_test_metadata): mdat1, mdat2, mdat_merged = vertical_merge_test_metadata - im1, im2, im_merged_expected = map(random_single_color_image_from_metadata, [mdat1, mdat2, mdat_merged]) + im1, im2, im_merged_expected = map(random_size_gray_image_from_metadata, [mdat1, mdat2, mdat_merged]) im_merged = concat_images_vertically(im1, im2, mdat_merged) assert im_merged.size == im_merged_expected.size - - -class Stitcher: - @staticmethod - def groupby(pairs, coord): - coord_getter = make_coord_getter(coord) - pairs = sorted(pairs, key=coord_getter) - return map(compose(list, second), groupby(pairs, coord_getter)) - - def stitch(self, pairs: Iterable[ImageMetadataPair]) -> ImageMetadataPair: - groups = self.groupby(pairs, "x1") - groups = chain.from_iterable(map(rpartial(self.groupby, "x2"), groups)) - groups = map(partial(sorted, key=y1_getter), groups) - groups = map(merge_group, groups) + assert images_equal(im_merged, im_merged_expected) @pytest.mark.parametrize("width", [160])