From 79cd31850db6ff872eafd7916ea8607dc01390d9 Mon Sep 17 00:00:00 2001 From: Matthias Bisping Date: Mon, 11 Apr 2022 16:47:47 +0200 Subject: [PATCH] fuzzy stitching WIP: added tolerance to stitching; added fuzzification function; added tests for grouping and (fuzzy and exact) --- image_prediction/stitching/grouping.py | 46 ++++++++++++++++++++++--- image_prediction/stitching/merging.py | 12 +++---- image_prediction/stitching/stitching.py | 6 ++-- test/unit_tests/image_stitching_test.py | 45 +++++++++++++++++++++--- test/utils/stitching.py | 23 ++++++++----- 5 files changed, 107 insertions(+), 25 deletions(-) diff --git a/image_prediction/stitching/grouping.py b/image_prediction/stitching/grouping.py index c6b0d18..26aceca 100644 --- a/image_prediction/stitching/grouping.py +++ b/image_prediction/stitching/grouping.py @@ -1,26 +1,64 @@ +from functools import lru_cache from itertools import groupby +import numpy as np from funcy import compose, second from image_prediction.stitching.utils import make_coord_getter class CoordGrouper: - def __init__(self, axis): + def __init__(self, axis, tolerance=0): self.c1_getter = make_coord_getter(f"{other_axis(axis)}1") self.c2_getter = make_coord_getter(f"{other_axis(axis)}2") + self.tolerance = tolerance def group_pairs_by_lesser_coordinate(self, pairs): - return group_by_coordinate(pairs, self.c1_getter) + return group_by_coordinate(pairs, self.c1_getter, self.tolerance) def group_pairs_by_greater_coordinate(self, pairs): - return group_by_coordinate(pairs, self.c2_getter) + return group_by_coordinate(pairs, self.c2_getter, self.tolerance) def other_axis(axis): return "y" if axis == "x" else "x" -def group_by_coordinate(pairs, coord_getter): +def fuzzify(func, tolerance): + def inner(item): + nonlocal mid_points + nonlocal lower_bounds + nonlocal upper_bounds + print(tolerance) + + value = func(item) + fits = (array(lower_bounds_array()) <= value) & (value <= array(upper_bounds_array())) + if any(fits): + return mid_points[np.argmax(fits)] + else: + mid_points = [*mid_points, value] + lower_bounds = [*lower_bounds, value - tolerance] + upper_bounds = [*upper_bounds, value + tolerance] + return value + + def lower_bounds_array(): + return tuple(lower_bounds) + + def upper_bounds_array(): + return tuple(upper_bounds) + + @lru_cache(maxsize=None) + def array(tpl): + return np.array(tpl) + + lower_bounds = [] + upper_bounds = [] + mid_points = [] + + return inner + + +def group_by_coordinate(pairs, coord_getter, tolerance=0): + coord_getter = fuzzify(coord_getter, tolerance) pairs = sorted(pairs, key=coord_getter) return map(compose(list, second), groupby(pairs, coord_getter)) diff --git a/image_prediction/stitching/merging.py b/image_prediction/stitching/merging.py index 5ec9d27..47c20cb 100644 --- a/image_prediction/stitching/merging.py +++ b/image_prediction/stitching/merging.py @@ -17,14 +17,14 @@ def no_new_merges(pairs1, pairs2): return len(pairs1) == len(pairs2) -def merge_along_both_axes(pairs: Iterable[ImageMetadataPair]) -> List[ImageMetadataPair]: - pairs = merge_along_axis(pairs, "x") - pairs = list(merge_along_axis(pairs, "y")) +def merge_along_both_axes(pairs: Iterable[ImageMetadataPair], tolerance=0) -> List[ImageMetadataPair]: + pairs = merge_along_axis(pairs, "x", tolerance=tolerance) + pairs = list(merge_along_axis(pairs, "y", tolerance=tolerance)) return pairs -def merge_along_axis(pairs: Iterable[ImageMetadataPair], axis) -> Iterable[ImageMetadataPair]: +def merge_along_axis(pairs: Iterable[ImageMetadataPair], axis, tolerance=0) -> Iterable[ImageMetadataPair]: """Partially merges image-metadata pairs of adjacent images along a given axis. Needs to be iterated with alternating axes until no more merges happen to merge all adjacent images. @@ -41,13 +41,13 @@ def merge_along_axis(pairs: Iterable[ImageMetadataPair], axis) -> Iterable[Image """ def group_pairs_within_groups_by_greater_coordinate(groups): - return map(CoordGrouper(axis).group_pairs_by_greater_coordinate, groups) + return map(CoordGrouper(axis, tolerance=tolerance).group_pairs_by_greater_coordinate, groups) def merge_groups_along_orthogonal_axis(groups): return map(make_group_merger(axis), groups) def group_pairs_by_lesser_coordinate(pairs): - return CoordGrouper(axis).group_pairs_by_lesser_coordinate(pairs) + return CoordGrouper(axis, tolerance=tolerance).group_pairs_by_lesser_coordinate(pairs) return rcompose( group_pairs_by_lesser_coordinate, diff --git a/image_prediction/stitching/stitching.py b/image_prediction/stitching/stitching.py index bb8c519..edc53de 100644 --- a/image_prediction/stitching/stitching.py +++ b/image_prediction/stitching/stitching.py @@ -1,11 +1,13 @@ from typing import Iterable +from funcy import rpartial + from image_prediction.image_extractor.extractor import ImageMetadataPair from image_prediction.stitching.merging import merge_along_both_axes, no_new_merges from image_prediction.utils.generic import until -def stitch_pairs(pairs: Iterable[ImageMetadataPair]) -> Iterable[ImageMetadataPair]: +def stitch_pairs(pairs: Iterable[ImageMetadataPair], tolerance=0) -> Iterable[ImageMetadataPair]: """Given a collection of image-metadata pairs from the same pages, combines all pairs that constitute adjacent images.""" - return until(no_new_merges, merge_along_both_axes, pairs) + return until(no_new_merges, rpartial(merge_along_both_axes, tolerance), pairs) diff --git a/test/unit_tests/image_stitching_test.py b/test/unit_tests/image_stitching_test.py index e469e47..2c67485 100644 --- a/test/unit_tests/image_stitching_test.py +++ b/test/unit_tests/image_stitching_test.py @@ -1,6 +1,7 @@ from copy import deepcopy from functools import partial from itertools import starmap, repeat +from operator import itemgetter from typing import List import fpdf @@ -8,11 +9,12 @@ import numpy as np import pdf2image import pytest from PIL import Image -from funcy import merge, juxt, one +from funcy import merge, juxt, one, first from image_prediction.estimator.preprocessor.utils import image_to_normalized_tensor from image_prediction.image_extractor.extractor import ImageMetadataPair from image_prediction.info import Info +from image_prediction.stitching.grouping import group_by_coordinate from image_prediction.stitching.stitching import stitch_pairs from image_prediction.stitching.utils import ( make_coord_getter, @@ -33,8 +35,33 @@ x1_getter, y1_getter, x2_getter, y2_getter = map(make_coord_getter, ("x1", "y1", width_getter, height_getter = map(make_length_getter, ("width", "height")) +# @pytest.mark.parametrize("noise", [(0, 3)]) +# @pytest.mark.parametrize("split_count", [2]) +# def test_image_stitcher_with_gaps(patch_image_metadata_pairs, base_patch_metadata, base_patch_image): +# pair_stitched = first(stitch_pairs(patch_image_metadata_pairs, tolerance=12)) +# pair_stitched.image.show() +# # base_patch_image.show() +# input() +# import IPython +# IPython.embed() +# assert pair_stitched.metadata == base_patch_metadata +# assert images_equal(pair_stitched.image.resize((10, 10)), base_patch_image.resize((10, 10)), atol=0.4) + + +def test_group_by_coordinate_exact(): + pairs = [(0, 1), (0, 3), (1, 4), (1, 4), (1, 2), (3, 3)] + pairs_grouped = list(group_by_coordinate(pairs, itemgetter(0), tolerance=0)) + assert pairs_grouped == [[(0, 1), (0, 3)], [(1, 4), (1, 4), (1, 2)], [(3, 3)]] + + +def test_group_by_coordinate_fuzzy(): + pairs = [(0, 1), (1, 3), (1, 4), (2, 4), (2, 2), (3, 3)] + pairs_grouped = list(group_by_coordinate(pairs, itemgetter(0), tolerance=1)) + assert pairs_grouped == [[(0, 1), (1, 3), (1, 4)], [(2, 4), (2, 2), (3, 3)]] + + def test_image_stitcher(patch_image_metadata_pairs, base_patch_metadata, base_patch_image): - pair_stitched = stitch_pairs(patch_image_metadata_pairs)[0] + pair_stitched = first(stitch_pairs(patch_image_metadata_pairs)) assert pair_stitched.metadata == base_patch_metadata assert images_equal(pair_stitched.image.resize((10, 10)), base_patch_image.resize((10, 10)), atol=0.4) @@ -192,6 +219,16 @@ def base_patch_metadata(width, height, page_width, page_height): @pytest.fixture -def patches_metadata(base_patch_metadata): - patches_metadata = list(BoxSplitter().split_box(base_patch_metadata)) +def patches_metadata(base_patch_metadata, noise, split_count): + patches_metadata = list(BoxSplitter(noise).split_box(base_patch_metadata, split_count)) return patches_metadata + + +@pytest.fixture(params=[(0, 0)]) +def noise(request): + return request.param + + +@pytest.fixture(params=[5]) +def split_count(request): + return request.param diff --git a/test/utils/stitching.py b/test/utils/stitching.py index dc4cc0d..0546a58 100644 --- a/test/utils/stitching.py +++ b/test/utils/stitching.py @@ -2,14 +2,17 @@ import random from copy import deepcopy from itertools import chain -from funcy import rpartial, juxt +from funcy import rpartial, juxt, first from image_prediction.stitching.split_mapper import SplitMapper, HorizontalSplitMapper, VerticalSplitMapper class BoxSplitter: - def __init__(self): + def __init__(self, noise=None): self.__steps = None + self.__noise = (0, 0) if not noise else noise + if not min(self.__noise) >= 0: + raise ValueError("Noise interval must be non-negative.") def split_box(self, box, steps=5): self.__steps = steps @@ -51,22 +54,24 @@ class BoxSplitter: else self.__base_case(wrapped_box.wrapped) ) + def noise(self): + return int(random.uniform(*self.__noise)) + @staticmethod def __large_enough(wrapped_box: SplitMapper): return wrapped_box.dim >= 10 - @staticmethod - def __get_child_boxes(wrapped_box: SplitMapper): + def __get_child_boxes(self, wrapped_box: SplitMapper): split_len = random.randint(5, wrapped_box.dim - 5) - split_point = wrapped_box.c1 + split_len + split_point = wrapped_box.c1 + split_len + self.noise() box_left, box_right = juxt(deepcopy, deepcopy)(wrapped_box) - box_left.dim = split_len - box_right.dim = wrapped_box.dim - split_len + box_left.dim = split_len + self.noise() + box_right.dim = wrapped_box.dim - split_len + self.noise() - box_left.c2 = split_point - box_right.c1 = split_point + box_left.c2 = split_point + self.noise() + box_right.c1 = split_point + self.noise() return box_left.wrapped, box_right.wrapped