fuzzy stitching WIP: added tolerance to stitching; added fuzzification function; added tests for grouping and (fuzzy and exact)
This commit is contained in:
parent
3d335783dc
commit
79cd31850d
@ -1,26 +1,64 @@
|
||||
from functools import lru_cache
|
||||
from itertools import groupby
|
||||
|
||||
import numpy as np
|
||||
from funcy import compose, second
|
||||
|
||||
from image_prediction.stitching.utils import make_coord_getter
|
||||
|
||||
|
||||
class CoordGrouper:
|
||||
def __init__(self, axis):
|
||||
def __init__(self, axis, tolerance=0):
|
||||
self.c1_getter = make_coord_getter(f"{other_axis(axis)}1")
|
||||
self.c2_getter = make_coord_getter(f"{other_axis(axis)}2")
|
||||
self.tolerance = tolerance
|
||||
|
||||
def group_pairs_by_lesser_coordinate(self, pairs):
|
||||
return group_by_coordinate(pairs, self.c1_getter)
|
||||
return group_by_coordinate(pairs, self.c1_getter, self.tolerance)
|
||||
|
||||
def group_pairs_by_greater_coordinate(self, pairs):
|
||||
return group_by_coordinate(pairs, self.c2_getter)
|
||||
return group_by_coordinate(pairs, self.c2_getter, self.tolerance)
|
||||
|
||||
|
||||
def other_axis(axis):
|
||||
return "y" if axis == "x" else "x"
|
||||
|
||||
|
||||
def group_by_coordinate(pairs, coord_getter):
|
||||
def fuzzify(func, tolerance):
|
||||
def inner(item):
|
||||
nonlocal mid_points
|
||||
nonlocal lower_bounds
|
||||
nonlocal upper_bounds
|
||||
print(tolerance)
|
||||
|
||||
value = func(item)
|
||||
fits = (array(lower_bounds_array()) <= value) & (value <= array(upper_bounds_array()))
|
||||
if any(fits):
|
||||
return mid_points[np.argmax(fits)]
|
||||
else:
|
||||
mid_points = [*mid_points, value]
|
||||
lower_bounds = [*lower_bounds, value - tolerance]
|
||||
upper_bounds = [*upper_bounds, value + tolerance]
|
||||
return value
|
||||
|
||||
def lower_bounds_array():
|
||||
return tuple(lower_bounds)
|
||||
|
||||
def upper_bounds_array():
|
||||
return tuple(upper_bounds)
|
||||
|
||||
@lru_cache(maxsize=None)
|
||||
def array(tpl):
|
||||
return np.array(tpl)
|
||||
|
||||
lower_bounds = []
|
||||
upper_bounds = []
|
||||
mid_points = []
|
||||
|
||||
return inner
|
||||
|
||||
|
||||
def group_by_coordinate(pairs, coord_getter, tolerance=0):
|
||||
coord_getter = fuzzify(coord_getter, tolerance)
|
||||
pairs = sorted(pairs, key=coord_getter)
|
||||
return map(compose(list, second), groupby(pairs, coord_getter))
|
||||
|
||||
@ -17,14 +17,14 @@ def no_new_merges(pairs1, pairs2):
|
||||
return len(pairs1) == len(pairs2)
|
||||
|
||||
|
||||
def merge_along_both_axes(pairs: Iterable[ImageMetadataPair]) -> List[ImageMetadataPair]:
|
||||
pairs = merge_along_axis(pairs, "x")
|
||||
pairs = list(merge_along_axis(pairs, "y"))
|
||||
def merge_along_both_axes(pairs: Iterable[ImageMetadataPair], tolerance=0) -> List[ImageMetadataPair]:
|
||||
pairs = merge_along_axis(pairs, "x", tolerance=tolerance)
|
||||
pairs = list(merge_along_axis(pairs, "y", tolerance=tolerance))
|
||||
|
||||
return pairs
|
||||
|
||||
|
||||
def merge_along_axis(pairs: Iterable[ImageMetadataPair], axis) -> Iterable[ImageMetadataPair]:
|
||||
def merge_along_axis(pairs: Iterable[ImageMetadataPair], axis, tolerance=0) -> Iterable[ImageMetadataPair]:
|
||||
"""Partially merges image-metadata pairs of adjacent images along a given axis. Needs to be iterated with
|
||||
alternating axes until no more merges happen to merge all adjacent images.
|
||||
|
||||
@ -41,13 +41,13 @@ def merge_along_axis(pairs: Iterable[ImageMetadataPair], axis) -> Iterable[Image
|
||||
"""
|
||||
|
||||
def group_pairs_within_groups_by_greater_coordinate(groups):
|
||||
return map(CoordGrouper(axis).group_pairs_by_greater_coordinate, groups)
|
||||
return map(CoordGrouper(axis, tolerance=tolerance).group_pairs_by_greater_coordinate, groups)
|
||||
|
||||
def merge_groups_along_orthogonal_axis(groups):
|
||||
return map(make_group_merger(axis), groups)
|
||||
|
||||
def group_pairs_by_lesser_coordinate(pairs):
|
||||
return CoordGrouper(axis).group_pairs_by_lesser_coordinate(pairs)
|
||||
return CoordGrouper(axis, tolerance=tolerance).group_pairs_by_lesser_coordinate(pairs)
|
||||
|
||||
return rcompose(
|
||||
group_pairs_by_lesser_coordinate,
|
||||
|
||||
@ -1,11 +1,13 @@
|
||||
from typing import Iterable
|
||||
|
||||
from funcy import rpartial
|
||||
|
||||
from image_prediction.image_extractor.extractor import ImageMetadataPair
|
||||
from image_prediction.stitching.merging import merge_along_both_axes, no_new_merges
|
||||
from image_prediction.utils.generic import until
|
||||
|
||||
|
||||
def stitch_pairs(pairs: Iterable[ImageMetadataPair]) -> Iterable[ImageMetadataPair]:
|
||||
def stitch_pairs(pairs: Iterable[ImageMetadataPair], tolerance=0) -> Iterable[ImageMetadataPair]:
|
||||
"""Given a collection of image-metadata pairs from the same pages, combines all pairs that constitute adjacent
|
||||
images."""
|
||||
return until(no_new_merges, merge_along_both_axes, pairs)
|
||||
return until(no_new_merges, rpartial(merge_along_both_axes, tolerance), pairs)
|
||||
|
||||
@ -1,6 +1,7 @@
|
||||
from copy import deepcopy
|
||||
from functools import partial
|
||||
from itertools import starmap, repeat
|
||||
from operator import itemgetter
|
||||
from typing import List
|
||||
|
||||
import fpdf
|
||||
@ -8,11 +9,12 @@ import numpy as np
|
||||
import pdf2image
|
||||
import pytest
|
||||
from PIL import Image
|
||||
from funcy import merge, juxt, one
|
||||
from funcy import merge, juxt, one, first
|
||||
|
||||
from image_prediction.estimator.preprocessor.utils import image_to_normalized_tensor
|
||||
from image_prediction.image_extractor.extractor import ImageMetadataPair
|
||||
from image_prediction.info import Info
|
||||
from image_prediction.stitching.grouping import group_by_coordinate
|
||||
from image_prediction.stitching.stitching import stitch_pairs
|
||||
from image_prediction.stitching.utils import (
|
||||
make_coord_getter,
|
||||
@ -33,8 +35,33 @@ x1_getter, y1_getter, x2_getter, y2_getter = map(make_coord_getter, ("x1", "y1",
|
||||
width_getter, height_getter = map(make_length_getter, ("width", "height"))
|
||||
|
||||
|
||||
# @pytest.mark.parametrize("noise", [(0, 3)])
|
||||
# @pytest.mark.parametrize("split_count", [2])
|
||||
# def test_image_stitcher_with_gaps(patch_image_metadata_pairs, base_patch_metadata, base_patch_image):
|
||||
# pair_stitched = first(stitch_pairs(patch_image_metadata_pairs, tolerance=12))
|
||||
# pair_stitched.image.show()
|
||||
# # base_patch_image.show()
|
||||
# input()
|
||||
# import IPython
|
||||
# IPython.embed()
|
||||
# assert pair_stitched.metadata == base_patch_metadata
|
||||
# assert images_equal(pair_stitched.image.resize((10, 10)), base_patch_image.resize((10, 10)), atol=0.4)
|
||||
|
||||
|
||||
def test_group_by_coordinate_exact():
|
||||
pairs = [(0, 1), (0, 3), (1, 4), (1, 4), (1, 2), (3, 3)]
|
||||
pairs_grouped = list(group_by_coordinate(pairs, itemgetter(0), tolerance=0))
|
||||
assert pairs_grouped == [[(0, 1), (0, 3)], [(1, 4), (1, 4), (1, 2)], [(3, 3)]]
|
||||
|
||||
|
||||
def test_group_by_coordinate_fuzzy():
|
||||
pairs = [(0, 1), (1, 3), (1, 4), (2, 4), (2, 2), (3, 3)]
|
||||
pairs_grouped = list(group_by_coordinate(pairs, itemgetter(0), tolerance=1))
|
||||
assert pairs_grouped == [[(0, 1), (1, 3), (1, 4)], [(2, 4), (2, 2), (3, 3)]]
|
||||
|
||||
|
||||
def test_image_stitcher(patch_image_metadata_pairs, base_patch_metadata, base_patch_image):
|
||||
pair_stitched = stitch_pairs(patch_image_metadata_pairs)[0]
|
||||
pair_stitched = first(stitch_pairs(patch_image_metadata_pairs))
|
||||
assert pair_stitched.metadata == base_patch_metadata
|
||||
assert images_equal(pair_stitched.image.resize((10, 10)), base_patch_image.resize((10, 10)), atol=0.4)
|
||||
|
||||
@ -192,6 +219,16 @@ def base_patch_metadata(width, height, page_width, page_height):
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def patches_metadata(base_patch_metadata):
|
||||
patches_metadata = list(BoxSplitter().split_box(base_patch_metadata))
|
||||
def patches_metadata(base_patch_metadata, noise, split_count):
|
||||
patches_metadata = list(BoxSplitter(noise).split_box(base_patch_metadata, split_count))
|
||||
return patches_metadata
|
||||
|
||||
|
||||
@pytest.fixture(params=[(0, 0)])
|
||||
def noise(request):
|
||||
return request.param
|
||||
|
||||
|
||||
@pytest.fixture(params=[5])
|
||||
def split_count(request):
|
||||
return request.param
|
||||
|
||||
@ -2,14 +2,17 @@ import random
|
||||
from copy import deepcopy
|
||||
from itertools import chain
|
||||
|
||||
from funcy import rpartial, juxt
|
||||
from funcy import rpartial, juxt, first
|
||||
|
||||
from image_prediction.stitching.split_mapper import SplitMapper, HorizontalSplitMapper, VerticalSplitMapper
|
||||
|
||||
|
||||
class BoxSplitter:
|
||||
def __init__(self):
|
||||
def __init__(self, noise=None):
|
||||
self.__steps = None
|
||||
self.__noise = (0, 0) if not noise else noise
|
||||
if not min(self.__noise) >= 0:
|
||||
raise ValueError("Noise interval must be non-negative.")
|
||||
|
||||
def split_box(self, box, steps=5):
|
||||
self.__steps = steps
|
||||
@ -51,22 +54,24 @@ class BoxSplitter:
|
||||
else self.__base_case(wrapped_box.wrapped)
|
||||
)
|
||||
|
||||
def noise(self):
|
||||
return int(random.uniform(*self.__noise))
|
||||
|
||||
@staticmethod
|
||||
def __large_enough(wrapped_box: SplitMapper):
|
||||
return wrapped_box.dim >= 10
|
||||
|
||||
@staticmethod
|
||||
def __get_child_boxes(wrapped_box: SplitMapper):
|
||||
def __get_child_boxes(self, wrapped_box: SplitMapper):
|
||||
|
||||
split_len = random.randint(5, wrapped_box.dim - 5)
|
||||
split_point = wrapped_box.c1 + split_len
|
||||
split_point = wrapped_box.c1 + split_len + self.noise()
|
||||
|
||||
box_left, box_right = juxt(deepcopy, deepcopy)(wrapped_box)
|
||||
|
||||
box_left.dim = split_len
|
||||
box_right.dim = wrapped_box.dim - split_len
|
||||
box_left.dim = split_len + self.noise()
|
||||
box_right.dim = wrapped_box.dim - split_len + self.noise()
|
||||
|
||||
box_left.c2 = split_point
|
||||
box_right.c1 = split_point
|
||||
box_left.c2 = split_point + self.noise()
|
||||
box_right.c1 = split_point + self.noise()
|
||||
|
||||
return box_left.wrapped, box_right.wrapped
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user