diff --git a/test/unit_tests/image_stitcher_test.py b/test/unit_tests/image_stitcher_test.py index 37c6fd5..1896fd9 100644 --- a/test/unit_tests/image_stitcher_test.py +++ b/test/unit_tests/image_stitcher_test.py @@ -151,12 +151,8 @@ def make_merger_aggregator(direction): def merge_group(group, direction): - reduce_group = make_merger_aggregator(direction) - - for g1, g2 in chunk_iterable(iterate(reduce_group, group), chunk_size=2): - if len(g1) == len(g2): - return g1 + return until_convergence(reduce_group, group) def merge_group_horizontally(group): @@ -204,20 +200,21 @@ class Stitcher: return pairs + def merge_along_both_axes(self, pairs): + + pairs = self.merge_along_axis(pairs, "x") + pairs = list(self.merge_along_axis(pairs, "y")) + + return pairs + def stitch(self, pairs: Iterable[ImageMetadataPair]) -> List[ImageMetadataPair]: + return until_convergence(self.merge_along_both_axes, pairs) - pairs = list(pairs) - n = len(pairs) - while True: - - pairs = self.merge_along_axis(pairs, "x") - pairs = list(self.merge_along_axis(pairs, "y")) - - if len(pairs) == n: - return pairs - - n = len(pairs) +def until_convergence(func, *args, **kwargs): + for a, b in chunk_iterable(iterate(func, *args, **kwargs), chunk_size=2): + if len(a) == len(b): + return a #####################################