diff --git a/test/unit_tests/image_stitcher_test.py b/test/unit_tests/image_stitcher_test.py index 1ec60d4..37c6fd5 100644 --- a/test/unit_tests/image_stitcher_test.py +++ b/test/unit_tests/image_stitcher_test.py @@ -168,6 +168,7 @@ def merge_group_vertically(group): class Stitcher: + @staticmethod def groupby(pairs, coord_getter): pairs = sorted(pairs, key=coord_getter) @@ -179,14 +180,27 @@ class Stitcher: def merge_along_axis(self, pairs, axis): + def group_pairs_by_c1(pairs): + return self.groupby(pairs, c1_getter) + + def group_by_c2(pairs): + return self.groupby(pairs, c2_getter) + + def group_pairs_within_groups_by_c2(groups): + return map(group_by_c2, groups) + + def merge_groups_along_orthogonal_axis(groups): + return map(group_merger, groups) + c1_getter = make_coord_getter(f"{self.other_axis(axis)}1") c2_getter = make_coord_getter(f"{self.other_axis(axis)}2") group_merger = make_group_merger(axis) - groups = self.groupby(pairs, c1_getter) - groups = chain.from_iterable(map(rpartial(self.groupby, c2_getter), groups)) - groups = map(group_merger, groups) - pairs = chain.from_iterable(groups) + groups_of_pairs_with_same_c1 = group_pairs_by_c1(pairs) + groups_of_groups_of_pairs_with_same_c1_and_c2 = group_pairs_within_groups_by_c2(groups_of_pairs_with_same_c1) + groups_of_pairs_with_matching_c1_and_c2 = chain(*groups_of_groups_of_pairs_with_same_c1_and_c2) + groups_of_merged_pairs = merge_groups_along_orthogonal_axis(groups_of_pairs_with_matching_c1_and_c2) + pairs = chain(*groups_of_merged_pairs) return pairs