refactoring

This commit is contained in:
Matthias Bisping 2022-04-07 21:20:55 +02:00
parent 7e6fe7cf11
commit e276a5ec27

View File

@ -151,12 +151,8 @@ def make_merger_aggregator(direction):
def merge_group(group, direction): def merge_group(group, direction):
reduce_group = make_merger_aggregator(direction) reduce_group = make_merger_aggregator(direction)
return until_convergence(reduce_group, group)
for g1, g2 in chunk_iterable(iterate(reduce_group, group), chunk_size=2):
if len(g1) == len(g2):
return g1
def merge_group_horizontally(group): def merge_group_horizontally(group):
@ -204,20 +200,21 @@ class Stitcher:
return pairs return pairs
def merge_along_both_axes(self, pairs):
pairs = self.merge_along_axis(pairs, "x")
pairs = list(self.merge_along_axis(pairs, "y"))
return pairs
def stitch(self, pairs: Iterable[ImageMetadataPair]) -> List[ImageMetadataPair]: def stitch(self, pairs: Iterable[ImageMetadataPair]) -> List[ImageMetadataPair]:
return until_convergence(self.merge_along_both_axes, pairs)
pairs = list(pairs)
n = len(pairs)
while True: def until_convergence(func, *args, **kwargs):
for a, b in chunk_iterable(iterate(func, *args, **kwargs), chunk_size=2):
pairs = self.merge_along_axis(pairs, "x") if len(a) == len(b):
pairs = list(self.merge_along_axis(pairs, "y")) return a
if len(pairs) == n:
return pairs
n = len(pairs)
##################################### #####################################