diff --git a/test/unit_tests/image_stitcher_test.py b/test/unit_tests/image_stitcher_test.py index 267cd40..1ec60d4 100644 --- a/test/unit_tests/image_stitcher_test.py +++ b/test/unit_tests/image_stitcher_test.py @@ -40,8 +40,12 @@ def make_coord_getter(c): }[c] -def make_pair_merger(direction): - return {"y": merge_pair_vertically, "x": merge_pair_horizontally}[direction] +def make_pair_merger(axis): + return {"y": merge_pair_vertically, "x": merge_pair_horizontally}[axis] + + +def make_group_merger(axis): + return {"y": merge_group_vertically, "x": merge_group_horizontally}[axis] def make_length_getter(dim): @@ -165,11 +169,27 @@ def merge_group_vertically(group): class Stitcher: @staticmethod - def groupby(pairs, coord): - coord_getter = make_coord_getter(coord) + def groupby(pairs, coord_getter): pairs = sorted(pairs, key=coord_getter) return map(compose(list, second), groupby(pairs, coord_getter)) + @staticmethod + def other_axis(axis): + return "y" if axis == "x" else "x" + + def merge_along_axis(self, pairs, axis): + + c1_getter = make_coord_getter(f"{self.other_axis(axis)}1") + c2_getter = make_coord_getter(f"{self.other_axis(axis)}2") + group_merger = make_group_merger(axis) + + groups = self.groupby(pairs, c1_getter) + groups = chain.from_iterable(map(rpartial(self.groupby, c2_getter), groups)) + groups = map(group_merger, groups) + pairs = chain.from_iterable(groups) + + return pairs + def stitch(self, pairs: Iterable[ImageMetadataPair]) -> List[ImageMetadataPair]: pairs = list(pairs) @@ -177,18 +197,12 @@ class Stitcher: while True: - groups = self.groupby(pairs, "x1") - groups = chain.from_iterable(map(rpartial(self.groupby, "x2"), groups)) - groups = map(merge_group_vertically, groups) - pairs = chain.from_iterable(groups) - - groups = self.groupby(pairs, "y1") - groups = chain.from_iterable(map(rpartial(self.groupby, "y2"), groups)) - groups = map(merge_group_horizontally, groups) - pairs = list(chain.from_iterable(groups)) + pairs = self.merge_along_axis(pairs, "x") + pairs = list(self.merge_along_axis(pairs, "y")) if len(pairs) == n: return pairs + n = len(pairs)