From e276a5ec2773c01c2f7d99a606b4468f9a6d976f Mon Sep 17 00:00:00 2001 From: Matthias Bisping Date: Thu, 7 Apr 2022 21:20:55 +0200 Subject: [PATCH] refactoring --- test/unit_tests/image_stitcher_test.py | 29 ++++++++++++-------------- 1 file changed, 13 insertions(+), 16 deletions(-) diff --git a/test/unit_tests/image_stitcher_test.py b/test/unit_tests/image_stitcher_test.py index 37c6fd5..1896fd9 100644 --- a/test/unit_tests/image_stitcher_test.py +++ b/test/unit_tests/image_stitcher_test.py @@ -151,12 +151,8 @@ def make_merger_aggregator(direction): def merge_group(group, direction): - reduce_group = make_merger_aggregator(direction) - - for g1, g2 in chunk_iterable(iterate(reduce_group, group), chunk_size=2): - if len(g1) == len(g2): - return g1 + return until_convergence(reduce_group, group) def merge_group_horizontally(group): @@ -204,20 +200,21 @@ class Stitcher: return pairs + def merge_along_both_axes(self, pairs): + + pairs = self.merge_along_axis(pairs, "x") + pairs = list(self.merge_along_axis(pairs, "y")) + + return pairs + def stitch(self, pairs: Iterable[ImageMetadataPair]) -> List[ImageMetadataPair]: + return until_convergence(self.merge_along_both_axes, pairs) - pairs = list(pairs) - n = len(pairs) - while True: - - pairs = self.merge_along_axis(pairs, "x") - pairs = list(self.merge_along_axis(pairs, "y")) - - if len(pairs) == n: - return pairs - - n = len(pairs) +def until_convergence(func, *args, **kwargs): + for a, b in chunk_iterable(iterate(func, *args, **kwargs), chunk_size=2): + if len(a) == len(b): + return a #####################################