fuzzy stitching WIP: added tolerance to stitching; added fuzzification function; added tests for grouping and (fuzzy and exact)

This commit is contained in:
Matthias Bisping 2022-04-11 16:47:47 +02:00
parent 3d335783dc
commit 79cd31850d
5 changed files with 107 additions and 25 deletions

View File

@ -1,26 +1,64 @@
from functools import lru_cache
from itertools import groupby
import numpy as np
from funcy import compose, second
from image_prediction.stitching.utils import make_coord_getter
class CoordGrouper:
def __init__(self, axis):
def __init__(self, axis, tolerance=0):
self.c1_getter = make_coord_getter(f"{other_axis(axis)}1")
self.c2_getter = make_coord_getter(f"{other_axis(axis)}2")
self.tolerance = tolerance
def group_pairs_by_lesser_coordinate(self, pairs):
return group_by_coordinate(pairs, self.c1_getter)
return group_by_coordinate(pairs, self.c1_getter, self.tolerance)
def group_pairs_by_greater_coordinate(self, pairs):
return group_by_coordinate(pairs, self.c2_getter)
return group_by_coordinate(pairs, self.c2_getter, self.tolerance)
def other_axis(axis):
return "y" if axis == "x" else "x"
def group_by_coordinate(pairs, coord_getter):
def fuzzify(func, tolerance):
def inner(item):
nonlocal mid_points
nonlocal lower_bounds
nonlocal upper_bounds
print(tolerance)
value = func(item)
fits = (array(lower_bounds_array()) <= value) & (value <= array(upper_bounds_array()))
if any(fits):
return mid_points[np.argmax(fits)]
else:
mid_points = [*mid_points, value]
lower_bounds = [*lower_bounds, value - tolerance]
upper_bounds = [*upper_bounds, value + tolerance]
return value
def lower_bounds_array():
return tuple(lower_bounds)
def upper_bounds_array():
return tuple(upper_bounds)
@lru_cache(maxsize=None)
def array(tpl):
return np.array(tpl)
lower_bounds = []
upper_bounds = []
mid_points = []
return inner
def group_by_coordinate(pairs, coord_getter, tolerance=0):
coord_getter = fuzzify(coord_getter, tolerance)
pairs = sorted(pairs, key=coord_getter)
return map(compose(list, second), groupby(pairs, coord_getter))

View File

@ -17,14 +17,14 @@ def no_new_merges(pairs1, pairs2):
return len(pairs1) == len(pairs2)
def merge_along_both_axes(pairs: Iterable[ImageMetadataPair]) -> List[ImageMetadataPair]:
pairs = merge_along_axis(pairs, "x")
pairs = list(merge_along_axis(pairs, "y"))
def merge_along_both_axes(pairs: Iterable[ImageMetadataPair], tolerance=0) -> List[ImageMetadataPair]:
pairs = merge_along_axis(pairs, "x", tolerance=tolerance)
pairs = list(merge_along_axis(pairs, "y", tolerance=tolerance))
return pairs
def merge_along_axis(pairs: Iterable[ImageMetadataPair], axis) -> Iterable[ImageMetadataPair]:
def merge_along_axis(pairs: Iterable[ImageMetadataPair], axis, tolerance=0) -> Iterable[ImageMetadataPair]:
"""Partially merges image-metadata pairs of adjacent images along a given axis. Needs to be iterated with
alternating axes until no more merges happen to merge all adjacent images.
@ -41,13 +41,13 @@ def merge_along_axis(pairs: Iterable[ImageMetadataPair], axis) -> Iterable[Image
"""
def group_pairs_within_groups_by_greater_coordinate(groups):
return map(CoordGrouper(axis).group_pairs_by_greater_coordinate, groups)
return map(CoordGrouper(axis, tolerance=tolerance).group_pairs_by_greater_coordinate, groups)
def merge_groups_along_orthogonal_axis(groups):
return map(make_group_merger(axis), groups)
def group_pairs_by_lesser_coordinate(pairs):
return CoordGrouper(axis).group_pairs_by_lesser_coordinate(pairs)
return CoordGrouper(axis, tolerance=tolerance).group_pairs_by_lesser_coordinate(pairs)
return rcompose(
group_pairs_by_lesser_coordinate,

View File

@ -1,11 +1,13 @@
from typing import Iterable
from funcy import rpartial
from image_prediction.image_extractor.extractor import ImageMetadataPair
from image_prediction.stitching.merging import merge_along_both_axes, no_new_merges
from image_prediction.utils.generic import until
def stitch_pairs(pairs: Iterable[ImageMetadataPair]) -> Iterable[ImageMetadataPair]:
def stitch_pairs(pairs: Iterable[ImageMetadataPair], tolerance=0) -> Iterable[ImageMetadataPair]:
"""Given a collection of image-metadata pairs from the same pages, combines all pairs that constitute adjacent
images."""
return until(no_new_merges, merge_along_both_axes, pairs)
return until(no_new_merges, rpartial(merge_along_both_axes, tolerance), pairs)

View File

@ -1,6 +1,7 @@
from copy import deepcopy
from functools import partial
from itertools import starmap, repeat
from operator import itemgetter
from typing import List
import fpdf
@ -8,11 +9,12 @@ import numpy as np
import pdf2image
import pytest
from PIL import Image
from funcy import merge, juxt, one
from funcy import merge, juxt, one, first
from image_prediction.estimator.preprocessor.utils import image_to_normalized_tensor
from image_prediction.image_extractor.extractor import ImageMetadataPair
from image_prediction.info import Info
from image_prediction.stitching.grouping import group_by_coordinate
from image_prediction.stitching.stitching import stitch_pairs
from image_prediction.stitching.utils import (
make_coord_getter,
@ -33,8 +35,33 @@ x1_getter, y1_getter, x2_getter, y2_getter = map(make_coord_getter, ("x1", "y1",
width_getter, height_getter = map(make_length_getter, ("width", "height"))
# @pytest.mark.parametrize("noise", [(0, 3)])
# @pytest.mark.parametrize("split_count", [2])
# def test_image_stitcher_with_gaps(patch_image_metadata_pairs, base_patch_metadata, base_patch_image):
# pair_stitched = first(stitch_pairs(patch_image_metadata_pairs, tolerance=12))
# pair_stitched.image.show()
# # base_patch_image.show()
# input()
# import IPython
# IPython.embed()
# assert pair_stitched.metadata == base_patch_metadata
# assert images_equal(pair_stitched.image.resize((10, 10)), base_patch_image.resize((10, 10)), atol=0.4)
def test_group_by_coordinate_exact():
pairs = [(0, 1), (0, 3), (1, 4), (1, 4), (1, 2), (3, 3)]
pairs_grouped = list(group_by_coordinate(pairs, itemgetter(0), tolerance=0))
assert pairs_grouped == [[(0, 1), (0, 3)], [(1, 4), (1, 4), (1, 2)], [(3, 3)]]
def test_group_by_coordinate_fuzzy():
pairs = [(0, 1), (1, 3), (1, 4), (2, 4), (2, 2), (3, 3)]
pairs_grouped = list(group_by_coordinate(pairs, itemgetter(0), tolerance=1))
assert pairs_grouped == [[(0, 1), (1, 3), (1, 4)], [(2, 4), (2, 2), (3, 3)]]
def test_image_stitcher(patch_image_metadata_pairs, base_patch_metadata, base_patch_image):
pair_stitched = stitch_pairs(patch_image_metadata_pairs)[0]
pair_stitched = first(stitch_pairs(patch_image_metadata_pairs))
assert pair_stitched.metadata == base_patch_metadata
assert images_equal(pair_stitched.image.resize((10, 10)), base_patch_image.resize((10, 10)), atol=0.4)
@ -192,6 +219,16 @@ def base_patch_metadata(width, height, page_width, page_height):
@pytest.fixture
def patches_metadata(base_patch_metadata):
patches_metadata = list(BoxSplitter().split_box(base_patch_metadata))
def patches_metadata(base_patch_metadata, noise, split_count):
patches_metadata = list(BoxSplitter(noise).split_box(base_patch_metadata, split_count))
return patches_metadata
@pytest.fixture(params=[(0, 0)])
def noise(request):
return request.param
@pytest.fixture(params=[5])
def split_count(request):
return request.param

View File

@ -2,14 +2,17 @@ import random
from copy import deepcopy
from itertools import chain
from funcy import rpartial, juxt
from funcy import rpartial, juxt, first
from image_prediction.stitching.split_mapper import SplitMapper, HorizontalSplitMapper, VerticalSplitMapper
class BoxSplitter:
def __init__(self):
def __init__(self, noise=None):
self.__steps = None
self.__noise = (0, 0) if not noise else noise
if not min(self.__noise) >= 0:
raise ValueError("Noise interval must be non-negative.")
def split_box(self, box, steps=5):
self.__steps = steps
@ -51,22 +54,24 @@ class BoxSplitter:
else self.__base_case(wrapped_box.wrapped)
)
def noise(self):
return int(random.uniform(*self.__noise))
@staticmethod
def __large_enough(wrapped_box: SplitMapper):
return wrapped_box.dim >= 10
@staticmethod
def __get_child_boxes(wrapped_box: SplitMapper):
def __get_child_boxes(self, wrapped_box: SplitMapper):
split_len = random.randint(5, wrapped_box.dim - 5)
split_point = wrapped_box.c1 + split_len
split_point = wrapped_box.c1 + split_len + self.noise()
box_left, box_right = juxt(deepcopy, deepcopy)(wrapped_box)
box_left.dim = split_len
box_right.dim = wrapped_box.dim - split_len
box_left.dim = split_len + self.noise()
box_right.dim = wrapped_box.dim - split_len + self.noise()
box_left.c2 = split_point
box_right.c1 = split_point
box_left.c2 = split_point + self.noise()
box_right.c1 = split_point + self.noise()
return box_left.wrapped, box_right.wrapped