refactoring

This commit is contained in:
Matthias Bisping 2022-04-07 20:47:58 +02:00
parent 8ac9fcb19f
commit bb5db1b4ef

View File

@ -40,8 +40,12 @@ def make_coord_getter(c):
}[c] }[c]
def make_pair_merger(direction): def make_pair_merger(axis):
return {"y": merge_pair_vertically, "x": merge_pair_horizontally}[direction] return {"y": merge_pair_vertically, "x": merge_pair_horizontally}[axis]
def make_group_merger(axis):
return {"y": merge_group_vertically, "x": merge_group_horizontally}[axis]
def make_length_getter(dim): def make_length_getter(dim):
@ -165,11 +169,27 @@ def merge_group_vertically(group):
class Stitcher: class Stitcher:
@staticmethod @staticmethod
def groupby(pairs, coord): def groupby(pairs, coord_getter):
coord_getter = make_coord_getter(coord)
pairs = sorted(pairs, key=coord_getter) pairs = sorted(pairs, key=coord_getter)
return map(compose(list, second), groupby(pairs, coord_getter)) return map(compose(list, second), groupby(pairs, coord_getter))
@staticmethod
def other_axis(axis):
return "y" if axis == "x" else "x"
def merge_along_axis(self, pairs, axis):
c1_getter = make_coord_getter(f"{self.other_axis(axis)}1")
c2_getter = make_coord_getter(f"{self.other_axis(axis)}2")
group_merger = make_group_merger(axis)
groups = self.groupby(pairs, c1_getter)
groups = chain.from_iterable(map(rpartial(self.groupby, c2_getter), groups))
groups = map(group_merger, groups)
pairs = chain.from_iterable(groups)
return pairs
def stitch(self, pairs: Iterable[ImageMetadataPair]) -> List[ImageMetadataPair]: def stitch(self, pairs: Iterable[ImageMetadataPair]) -> List[ImageMetadataPair]:
pairs = list(pairs) pairs = list(pairs)
@ -177,18 +197,12 @@ class Stitcher:
while True: while True:
groups = self.groupby(pairs, "x1") pairs = self.merge_along_axis(pairs, "x")
groups = chain.from_iterable(map(rpartial(self.groupby, "x2"), groups)) pairs = list(self.merge_along_axis(pairs, "y"))
groups = map(merge_group_vertically, groups)
pairs = chain.from_iterable(groups)
groups = self.groupby(pairs, "y1")
groups = chain.from_iterable(map(rpartial(self.groupby, "y2"), groups))
groups = map(merge_group_horizontally, groups)
pairs = list(chain.from_iterable(groups))
if len(pairs) == n: if len(pairs) == n:
return pairs return pairs
n = len(pairs) n = len(pairs)