diff --git a/image_prediction/stitcher/stitcher.py b/image_prediction/stitcher/stitcher.py index 605790e..441ea2e 100644 --- a/image_prediction/stitcher/stitcher.py +++ b/image_prediction/stitcher/stitcher.py @@ -8,34 +8,43 @@ from image_prediction.stitcher.utils import make_coord_getter, make_group_merger from image_prediction.utils.generic import until +def group_by_coordinate(pairs, coord_getter): + pairs = sorted(pairs, key=coord_getter) + return map(compose(list, second), groupby(pairs, coord_getter)) + + +class CoordGrouper: + def __init__(self, axis): + + self.c1_getter = make_coord_getter(f"{other_axis(axis)}1") + self.c2_getter = make_coord_getter(f"{other_axis(axis)}2") + + def get_c1(self, pair: ImageMetadataPair): + return self.c1_getter(pair) + + def group_pairs_by_c1(self, pairs): + return group_by_coordinate(pairs, self.c1_getter) + + def group_pairs_by_c2(self, pairs): + return group_by_coordinate(pairs, self.c2_getter) + + +def other_axis(axis): + return "y" if axis == "x" else "x" + + class Stitcher: - @staticmethod - def groupby(pairs, coord_getter): - pairs = sorted(pairs, key=coord_getter) - return map(compose(list, second), groupby(pairs, coord_getter)) - - @staticmethod - def other_axis(axis): - return "y" if axis == "x" else "x" - def merge_along_axis(self, pairs, axis): - def group_pairs_by_c1(pairs): - return self.groupby(pairs, c1_getter) - - def group_by_c2(pairs): - return self.groupby(pairs, c2_getter) - def group_pairs_within_groups_by_c2(groups): - return map(group_by_c2, groups) + return map(grouper.group_pairs_by_c2, groups) def merge_groups_along_orthogonal_axis(groups): return map(group_merger, groups) - c1_getter = make_coord_getter(f"{self.other_axis(axis)}1") - c2_getter = make_coord_getter(f"{self.other_axis(axis)}2") + grouper = CoordGrouper(axis) group_merger = make_group_merger(axis) - groups_of_pairs_with_same_c1 = group_pairs_by_c1(pairs) + groups_of_pairs_with_same_c1 = grouper.group_pairs_by_c1(pairs) groups_of_groups_of_pairs_with_same_c1_and_c2 = group_pairs_within_groups_by_c2(groups_of_pairs_with_same_c1) groups_of_pairs_with_matching_c1_and_c2 = chain(*groups_of_groups_of_pairs_with_same_c1_and_c2) groups_of_merged_pairs = merge_groups_along_orthogonal_axis(groups_of_pairs_with_matching_c1_and_c2)