"""Group rectangle pairs by a (tolerance-fuzzified) coordinate value."""

from functools import lru_cache
from itertools import groupby

import numpy as np
from funcy import compose, second

from image_prediction.stitching.utils import make_coord_getter


class CoordGrouper:
    """Group pairs by a coordinate on the axis orthogonal to ``axis``.

    The getters are built via ``make_coord_getter`` for the names
    ``"<other>1"`` and ``"<other>2"``, where ``<other>`` is
    ``other_axis(axis)`` (e.g. ``axis="x"`` -> ``"y1"`` / ``"y2"``).
    NOTE(review): presumably the ``1`` coordinate is the lesser and the ``2``
    coordinate the greater one of an item -- confirm against
    ``make_coord_getter``.
    """

    def __init__(self, axis, tolerance=0):
        """Create a grouper.

        Args:
            axis: Axis name; the *other* axis is the one whose coordinates
                are used for grouping (see ``other_axis``).
            tolerance: Maximum distance from a bucket's first-seen value for
                another value to fall into the same bucket.
        """
        orthogonal = other_axis(axis)
        self.c1_getter = make_coord_getter(f"{orthogonal}1")
        self.c2_getter = make_coord_getter(f"{orthogonal}2")
        self.tolerance = tolerance

    def group_pairs_by_lesser_coordinate(self, pairs):
        """Group ``pairs`` by the ``...1`` coordinate; see ``group_by_coordinate``."""
        return group_by_coordinate(pairs, self.c1_getter, self.tolerance)

    def group_pairs_by_greater_coordinate(self, pairs):
        """Group ``pairs`` by the ``...2`` coordinate; see ``group_by_coordinate``."""
        return group_by_coordinate(pairs, self.c2_getter, self.tolerance)


def other_axis(axis):
    """Return ``"y"`` for ``"x"``; any other input maps to ``"x"``."""
    return "y" if axis == "x" else "x"


def fuzzify(func, tolerance):
    """Wrap ``func`` so that nearby results collapse onto one representative.

    The returned callable is stateful. The first value that matches no
    existing bucket opens a new bucket covering the inclusive interval
    ``[value - tolerance, value + tolerance]`` and is returned unchanged.
    Every later value that falls into an existing bucket is replaced by that
    bucket's first-seen value. If several buckets match, the one that was
    created first wins. The outcome therefore depends on the order in which
    items are passed in.

    With a negative ``tolerance`` the intervals are empty, so no value ever
    matches a bucket and every call returns ``func(item)`` unchanged.

    Args:
        func: Callable extracting a numeric value from an item.
        tolerance: Half-width of each bucket's interval.

    Returns:
        A callable ``item -> representative value``.
    """
    mid_points = []
    lower_bounds = []
    upper_bounds = []
    # Array views of the bounds, rebuilt only when a bucket is added. This
    # replaces the former per-call tuple conversion feeding an unbounded
    # lru_cache, which retained one ever-longer tuple/array per bucket.
    lower_arr = np.array(lower_bounds)
    upper_arr = np.array(upper_bounds)

    def inner(item):
        nonlocal lower_arr, upper_arr

        value = func(item)
        fits = (lower_arr <= value) & (value <= upper_arr)
        if fits.any():
            # argmax on a boolean array yields the index of the first True,
            # i.e. the earliest-created bucket containing the value.
            return mid_points[np.argmax(fits)]

        mid_points.append(value)
        lower_bounds.append(value - tolerance)
        upper_bounds.append(value + tolerance)
        lower_arr = np.array(lower_bounds)
        upper_arr = np.array(upper_bounds)
        return value

    return inner


def group_by_coordinate(pairs, coord_getter, tolerance=0):
    """Group ``pairs`` whose coordinates lie within ``tolerance`` of each other.

    Args:
        pairs: Iterable of items accepted by ``coord_getter``.
        coord_getter: Callable extracting the grouping coordinate of an item.
        tolerance: Passed to ``fuzzify``; ``0`` groups by exact value.

    Returns:
        A lazy iterator of lists, one list per bucket, ordered by ascending
        bucket representative. Buckets are formed in the input order of
        ``pairs`` (see ``fuzzify``).
    """
    # The same stateful getter must serve both the sort and the groupby so
    # that the keys computed in the second pass agree with the first.
    coord_getter = fuzzify(coord_getter, tolerance)
    pairs = sorted(pairs, key=coord_getter)
    return map(compose(list, second), groupby(pairs, coord_getter))