fuzzy stitching WIP: mostly works, but sometimes fails. run test_image_stitcher_with_gaps to debug

This commit is contained in:
Matthias Bisping 2022-04-11 19:20:47 +02:00
parent 79cd31850d
commit bb7c1be630
7 changed files with 119 additions and 47 deletions

View File

@ -29,7 +29,6 @@ def fuzzify(func, tolerance):
nonlocal mid_points
nonlocal lower_bounds
nonlocal upper_bounds
print(tolerance)
value = func(item)
fits = (array(lower_bounds_array()) <= value) & (value <= array(upper_bounds_array()))

View File

@ -3,7 +3,7 @@ from functools import reduce
from typing import Iterable, Callable, List
from PIL import Image
from funcy import juxt, first, rest, rcompose
from funcy import juxt, first, rest, rcompose, rpartial
from image_prediction.image_extractor.extractor import ImageMetadataPair
from image_prediction.info import Info
@ -44,7 +44,7 @@ def merge_along_axis(pairs: Iterable[ImageMetadataPair], axis, tolerance=0) -> I
return map(CoordGrouper(axis, tolerance=tolerance).group_pairs_by_greater_coordinate, groups)
def merge_groups_along_orthogonal_axis(groups):
return map(make_group_merger(axis), groups)
return map(rpartial(make_group_merger(axis), tolerance), groups)
def group_pairs_by_lesser_coordinate(pairs):
return CoordGrouper(axis, tolerance=tolerance).group_pairs_by_lesser_coordinate(pairs)
@ -62,29 +62,33 @@ def make_group_merger(axis):
return {"y": merge_group_vertically, "x": merge_group_horizontally}[axis]
def merge_group_vertically(group: Iterable[ImageMetadataPair]):
return merge_group(group, "y")
def merge_group_vertically(group: Iterable[ImageMetadataPair], tolerance=0):
return merge_group(group, "y", tolerance=tolerance)
def merge_group_horizontally(group: Iterable[ImageMetadataPair]):
return merge_group(group, "x")
def merge_group_horizontally(group: Iterable[ImageMetadataPair], tolerance=0):
return merge_group(group, "x", tolerance=tolerance)
def merge_group(group: Iterable[ImageMetadataPair], direction):
reduce_group = make_merger_aggregator(direction)
def merge_group(group: Iterable[ImageMetadataPair], direction, tolerance=0):
reduce_group = make_merger_aggregator(direction, tolerance=tolerance)
return until(no_new_merges, reduce_group, group)
def make_merger_aggregator(axis) -> Callable[[Iterable[ImageMetadataPair]], Iterable[ImageMetadataPair]]:
def make_merger_aggregator(axis, tolerance=0) -> Callable[[Iterable[ImageMetadataPair]], Iterable[ImageMetadataPair]]:
"""Produces a function f : [H, T1, ... Tn] -> [HTi...Tj, Tk ... Tl] that merges adjacent image-metadata pairs on the
head H and aggregates non-adjacent in the tail T.
Note:
When tolerance > 0, the bounding box of the merged image no longer matches the bounding box of the merged
metadata. This is intended behaviour, but might not be expected by the caller.
"""
def merger_aggregator(pairs: Iterable[ImageMetadataPair]):
def merge_on_head_and_aggregate_in_tail(pairs_aggr: Iterable[ImageMetadataPair], pair: ImageMetadataPair):
"""Keeps the image that is being merged with as the head and aggregates non-mergables in the tail."""
aggr, non_aggr = juxt(first, rest)(pairs_aggr)
if c2_getter(aggr) == c1_getter(pair):
if abs(c2_getter(aggr) - c1_getter(pair)) <= tolerance:
aggr = pair_merger(aggr, pair)
return aggr, *non_aggr
else:
@ -96,6 +100,8 @@ def make_merger_aggregator(axis) -> Callable[[Iterable[ImageMetadataPair]], Iter
head_pair, pairs = juxt(first, rest)(pairs)
return list(reduce(merge_on_head_and_aggregate_in_tail, pairs, [head_pair]))
assert tolerance >= 0
c1_getter = make_coord_getter(f"{axis}1")
c2_getter = make_coord_getter(f"{axis}2")
pair_merger = make_pair_merger(axis)

View File

@ -1,4 +1,4 @@
from typing import Iterable
from typing import Iterable, List
from funcy import rpartial
@ -7,7 +7,7 @@ from image_prediction.stitching.merging import merge_along_both_axes, no_new_mer
from image_prediction.utils.generic import until
def stitch_pairs(pairs: Iterable[ImageMetadataPair], tolerance=0) -> Iterable[ImageMetadataPair]:
def stitch_pairs(pairs: Iterable[ImageMetadataPair], tolerance=0) -> List[ImageMetadataPair]:
"""Given a collection of image-metadata pairs from the same pages, combines all pairs that constitute adjacent
images."""
return until(no_new_merges, rpartial(merge_along_both_axes, tolerance), pairs)

View File

@ -12,7 +12,7 @@ import fpdf
import numpy as np
import pytest
from PIL import Image
from funcy import rcompose
from funcy import rcompose, merge
from image_prediction.classifier.classifier import Classifier
from image_prediction.classifier.image_classifier import ImageClassifier
@ -460,6 +460,13 @@ def get_base_position_metadata(width, height, page_width, page_height):
}
@pytest.fixture
def base_patch_metadata(width, height, page_width, page_height):
    """Build the metadata of the base patch that the stitching tests split and re-assemble.

    The result of ``get_base_position_metadata`` is merged with box coordinates
    X1 = 0, Y1 = 0, X2 = ``width`` and Y2 = ``height``.
    """
    position_metadata = get_base_position_metadata(width, height, page_width, page_height)
    box_coordinates = {Info.X1: 0, Info.Y1: 0, Info.X2: width, Info.Y2: height}
    return merge(position_metadata, box_coordinates)
@pytest.fixture(params=[33, 100])
def height(request):
    """Parametrized fixture: yields each height listed in ``params`` (33 and 100) in turn."""
    return request.param

View File

@ -1,4 +1,5 @@
from copy import deepcopy
from copy import deepcopy
from functools import partial
from itertools import starmap, repeat
from operator import itemgetter
@ -9,22 +10,28 @@ import numpy as np
import pdf2image
import pytest
from PIL import Image
from funcy import merge, juxt, one, first
from funcy import juxt, one, first
from image_prediction.estimator.preprocessor.utils import image_to_normalized_tensor
from image_prediction.image_extractor.extractor import ImageMetadataPair
from image_prediction.info import Info
from image_prediction.stitching.grouping import group_by_coordinate
from image_prediction.stitching.merging import (
merge_metadata_horizontally,
merge_metadata_vertically,
merge_pair_horizontally,
merge_pair_vertically,
concat_images_horizontally,
concat_images_vertically,
merge_group_horizontally,
merge_group_vertically,
)
from image_prediction.stitching.stitching import stitch_pairs
from image_prediction.stitching.utils import (
make_coord_getter,
make_length_getter,
)
from image_prediction.stitching.merging import merge_metadata_horizontally, merge_metadata_vertically, \
merge_pair_horizontally, merge_pair_vertically, concat_images_horizontally, concat_images_vertically, \
merge_group_horizontally, merge_group_vertically
from test.conftest import (
get_base_position_metadata,
add_image,
random_single_color_image_from_metadata,
random_size_gray_image_from_metadata,
@ -35,19 +42,6 @@ x1_getter, y1_getter, x2_getter, y2_getter = map(make_coord_getter, ("x1", "y1",
width_getter, height_getter = map(make_length_getter, ("width", "height"))
# @pytest.mark.parametrize("noise", [(0, 3)])
# @pytest.mark.parametrize("split_count", [2])
# def test_image_stitcher_with_gaps(patch_image_metadata_pairs, base_patch_metadata, base_patch_image):
# pair_stitched = first(stitch_pairs(patch_image_metadata_pairs, tolerance=12))
# pair_stitched.image.show()
# # base_patch_image.show()
# input()
# import IPython
# IPython.embed()
# assert pair_stitched.metadata == base_patch_metadata
# assert images_equal(pair_stitched.image.resize((10, 10)), base_patch_image.resize((10, 10)), atol=0.4)
def test_group_by_coordinate_exact():
pairs = [(0, 1), (0, 3), (1, 4), (1, 4), (1, 2), (3, 3)]
pairs_grouped = list(group_by_coordinate(pairs, itemgetter(0), tolerance=0))
@ -61,11 +55,35 @@ def test_group_by_coordinate_fuzzy():
def test_image_stitcher(patch_image_metadata_pairs, base_patch_metadata, base_patch_image):
pair_stitched = first(stitch_pairs(patch_image_metadata_pairs))
pairs_stitched = stitch_pairs(patch_image_metadata_pairs)
pair_stitched = first(pairs_stitched)
assert len(pairs_stitched) == 1
assert pair_stitched.metadata == base_patch_metadata
assert images_equal(pair_stitched.image.resize((10, 10)), base_patch_image.resize((10, 10)), atol=0.4)
@pytest.mark.parametrize("noise", [(0, 2)])
@pytest.mark.parametrize("split_count", [5])
@pytest.mark.parametrize("width", [100])
@pytest.mark.parametrize("height", [100])
@pytest.mark.parametrize("page_width", [100])
@pytest.mark.parametrize("page_height", [100])
@pytest.mark.parametrize("execution_number", range(100))
def test_image_stitcher_with_gaps(patch_image_metadata_pairs, base_patch_metadata, base_patch_image, execution_number):
    """Debug harness: patches stitched with a tolerance of 7 must collapse into a single pair.

    The test is repeated 100 times via ``execution_number``.
    NOTE(review): presumably each repetition draws fresh random noise for the patch split,
    which is why the failure is intermittent -- confirm against the patch fixtures.

    When the assertion fails, every stitched image and the base image are shown and an
    IPython shell is opened for inspection (run pytest with ``-s`` so the shell is usable).
    The failure is then re-raised so that pytest still reports the test as failed.
    """
    print(len(patch_image_metadata_pairs))
    pairs_stitched = stitch_pairs(patch_image_metadata_pairs, tolerance=7)
    try:
        assert len(pairs_stitched) == 1
    except AssertionError:
        # Only the stitching assertion should open the debug session; any other
        # exception propagates untouched instead of being swallowed.
        for p in pairs_stitched:
            p.image.show()
        base_patch_image.show()
        import IPython

        IPython.embed()
        # Re-raise: without this the test would pass even though stitching failed.
        raise
def test_merge_group_horizontally(horizontal_merge_test_pairs):
pr1, pr2, pr_merged_expected = horizontal_merge_test_pairs
@ -211,13 +229,6 @@ def patch_image_metadata_pairs(patches_metadata) -> List[ImageMetadataPair]:
return list(starmap(ImageMetadataPair, zip(images, patches_metadata)))
@pytest.fixture
def base_patch_metadata(width, height, page_width, page_height):
metadata = get_base_position_metadata(width, height, page_width, page_height)
metadata = merge(metadata, {Info.X1: 0, Info.Y1: 0, Info.X2: width, Info.Y2: height})
return metadata
@pytest.fixture
def patches_metadata(base_patch_metadata, noise, split_count):
patches_metadata = list(BoxSplitter(noise).split_box(base_patch_metadata, split_count))

View File

@ -0,0 +1,50 @@
from image_prediction.info import Info
from image_prediction.stitching.split_mapper import VerticalSplitMapper, HorizontalSplitMapper
def test_split_vertical_mapper(base_patch_metadata):
    """Check that VerticalSplitMapper exposes Y1, Y2 and HEIGHT as ``c1``, ``c2`` and ``dim``.

    Both in-place increments and plain assignments on the mapper must be visible
    through the mapper attributes and through the ``wrapped`` metadata.
    """
    # Case 1: in-place increments relative to the base metadata.
    mapper = VerticalSplitMapper(base_patch_metadata)
    mapper.c1 += 10 + 3
    mapper.c2 += 20 + 3
    mapper.dim += 20 + 3
    wrapped = mapper.wrapped

    assert wrapped[Info.Y1] == mapper.c1 == base_patch_metadata[Info.Y1] + 10 + 3
    assert wrapped[Info.Y2] == mapper.c2 == base_patch_metadata[Info.Y2] + 20 + 3
    assert wrapped[Info.HEIGHT] == mapper.dim == base_patch_metadata[Info.HEIGHT] + 20 + 3

    # Case 2: absolute assignments on a fresh mapper.
    mapper = VerticalSplitMapper(base_patch_metadata)
    mapper.c1 = 10 + 3
    mapper.c2 = 20 + 3
    mapper.dim = 20 + 3
    wrapped = mapper.wrapped

    assert wrapped[Info.Y1] == mapper.c1 == 10 + 3
    assert wrapped[Info.Y2] == mapper.c2 == 20 + 3
    assert wrapped[Info.HEIGHT] == mapper.dim == 20 + 3
def test_split_horizontal_mapper(base_patch_metadata):
    """Check that HorizontalSplitMapper exposes X1, X2 and WIDTH as ``c1``, ``c2`` and ``dim``.

    Both in-place increments and plain assignments on the mapper must be visible
    through the mapper attributes and through the ``wrapped`` metadata.
    """
    # Case 1: in-place increments relative to the base metadata.
    mapper = HorizontalSplitMapper(base_patch_metadata)
    mapper.c1 += 10 + 3
    mapper.c2 += 20 + 3
    mapper.dim += 20 + 3
    wrapped = mapper.wrapped

    assert wrapped[Info.X1] == mapper.c1 == base_patch_metadata[Info.X1] + 10 + 3
    assert wrapped[Info.X2] == mapper.c2 == base_patch_metadata[Info.X2] + 20 + 3
    assert wrapped[Info.WIDTH] == mapper.dim == base_patch_metadata[Info.WIDTH] + 20 + 3

    # Case 2: absolute assignments on a fresh mapper.
    mapper = HorizontalSplitMapper(base_patch_metadata)
    mapper.c1 = 10 + 3
    mapper.c2 = 20 + 3
    mapper.dim = 20 + 3
    wrapped = mapper.wrapped

    assert wrapped[Info.X1] == mapper.c1 == 10 + 3
    assert wrapped[Info.X2] == mapper.c2 == 20 + 3
    assert wrapped[Info.WIDTH] == mapper.dim == 20 + 3

View File

@ -11,8 +11,6 @@ class BoxSplitter:
def __init__(self, noise=None):
self.__steps = None
self.__noise = (0, 0) if not noise else noise
if not min(self.__noise) >= 0:
raise ValueError("Noise interval must be non-negative.")
def split_box(self, box, steps=5):
self.__steps = steps
@ -55,7 +53,7 @@ class BoxSplitter:
)
def noise(self):
return int(random.uniform(*self.__noise))
return int(round(random.uniform(*self.__noise)))
@staticmethod
def __large_enough(wrapped_box: SplitMapper):
@ -64,14 +62,15 @@ class BoxSplitter:
def __get_child_boxes(self, wrapped_box: SplitMapper):
split_len = random.randint(5, wrapped_box.dim - 5)
split_point = wrapped_box.c1 + split_len + self.noise()
split_point = wrapped_box.c1 + split_len
box_left, box_right = juxt(deepcopy, deepcopy)(wrapped_box)
box_left.dim = split_len + self.noise()
box_right.dim = wrapped_box.dim - split_len + self.noise()
noise = - self.noise()
box_left.dim = split_len + noise
box_right.dim = wrapped_box.dim - split_len
box_left.c2 = split_point + self.noise()
box_left.c2 = split_point + noise
box_right.c1 = split_point + self.noise()
return box_left.wrapped, box_right.wrapped