from itertools import chain, groupby
from typing import Iterable, List

from funcy import compose, second

from image_prediction.image_extractor.extractor import ImageMetadataPair
from image_prediction.stitcher.utils import make_coord_getter, make_group_merger
from image_prediction.utils.generic import until_convergence


class Stitcher:
    """Merge image/metadata pairs whose coordinates line up along an axis.

    A merge pass along one axis groups pairs that have identical start and end
    coordinates on the *other* axis and hands each group to a merger produced
    by ``make_group_merger``. ``stitch`` alternates x- and y-passes via
    ``until_convergence``.
    """

    @staticmethod
    def groupby(pairs, coord_getter):
        """Group ``pairs`` by the key returned by ``coord_getter``.

        The input is sorted by the same key first, because ``itertools.groupby``
        only groups *consecutive* items with equal keys. Each group is turned
        into a list as soon as it is produced, so consumers may advance the
        returned iterator freely.

        Args:
            pairs: iterable of pairs; it is fully consumed by the sort.
            coord_getter: key function applied to each pair.

        Returns:
            A lazy iterator of lists, one list per distinct key, in key order.
        """
        pairs = sorted(pairs, key=coord_getter)
        # Inside a static method the name ``groupby`` resolves to the module-level
        # ``itertools.groupby``, not to this method (class scope is not enclosing).
        return map(compose(list, second), groupby(pairs, coord_getter))

    @staticmethod
    def other_axis(axis):
        """Return ``"y"`` when ``axis == "x"``; otherwise return ``"x"``."""
        return "y" if axis == "x" else "x"

    def merge_along_axis(
        self, pairs: Iterable[ImageMetadataPair], axis: str
    ) -> Iterable[ImageMetadataPair]:
        """Merge pairs that share the same extent on the axis orthogonal to ``axis``.

        Pairs are grouped by their ``<other>1`` coordinate and then, within each
        of those groups, by their ``<other>2`` coordinate. Here ``<other>`` is
        ``other_axis(axis)``. Every resulting group therefore holds pairs whose
        start and end coordinates on the other axis match exactly. Each such
        group is passed to ``make_group_merger(axis)``, and the merger outputs
        are flattened into one stream.

        Args:
            pairs: iterable of pairs to merge (consumed).
            axis: the axis to merge along, ``"x"`` or ``"y"``.

        Returns:
            A lazy iterator over the resulting pairs.
        """
        other = self.other_axis(axis)
        # presumably these read the "<other>1"/"<other>2" coordinate of a pair's
        # metadata — TODO confirm against make_coord_getter.
        c1_getter = make_coord_getter(f"{other}1")
        c2_getter = make_coord_getter(f"{other}2")
        group_merger = make_group_merger(axis)

        groups_with_same_c1 = self.groupby(pairs, c1_getter)
        # Split every c1-group by c2 and flatten one level, so each resulting
        # group shares both c1 and c2.
        groups_with_same_c1_and_c2 = chain.from_iterable(
            self.groupby(group, c2_getter) for group in groups_with_same_c1
        )
        # Each merger call yields an iterable of pairs (its output is flattened
        # here), presumably with adjacent pairs merged into one.
        merged_groups = map(group_merger, groups_with_same_c1_and_c2)
        return chain.from_iterable(merged_groups)

    def merge_along_both_axes(
        self, pairs: Iterable[ImageMetadataPair]
    ) -> List[ImageMetadataPair]:
        """Run one merge pass along ``"x"`` followed by one along ``"y"``.

        Args:
            pairs: iterable of pairs (consumed).

        Returns:
            The pairs after both passes, materialised as a list.
        """
        pairs = self.merge_along_axis(pairs, "x")
        return list(self.merge_along_axis(pairs, "y"))

    def stitch(self, pairs: Iterable[ImageMetadataPair]) -> List[ImageMetadataPair]:
        """Stitch pairs by repeatedly merging along both axes.

        Delegates to ``until_convergence`` with ``merge_along_both_axes`` as the
        step function. NOTE(review): presumably ``until_convergence`` reapplies
        the step until its output stops changing — confirm against its
        implementation.

        Args:
            pairs: iterable of image/metadata pairs to stitch.

        Returns:
            Whatever ``until_convergence`` returns for the final step result.
        """
        return until_convergence(self.merge_along_both_axes, pairs)