fuzzy stitching WIP: mostly works, but sometimes fails. run test_image_stitcher_with_gaps to debug
This commit is contained in:
parent
79cd31850d
commit
bb7c1be630
@ -29,7 +29,6 @@ def fuzzify(func, tolerance):
|
||||
nonlocal mid_points
|
||||
nonlocal lower_bounds
|
||||
nonlocal upper_bounds
|
||||
print(tolerance)
|
||||
|
||||
value = func(item)
|
||||
fits = (array(lower_bounds_array()) <= value) & (value <= array(upper_bounds_array()))
|
||||
|
||||
@ -3,7 +3,7 @@ from functools import reduce
|
||||
from typing import Iterable, Callable, List
|
||||
|
||||
from PIL import Image
|
||||
from funcy import juxt, first, rest, rcompose
|
||||
from funcy import juxt, first, rest, rcompose, rpartial
|
||||
|
||||
from image_prediction.image_extractor.extractor import ImageMetadataPair
|
||||
from image_prediction.info import Info
|
||||
@ -44,7 +44,7 @@ def merge_along_axis(pairs: Iterable[ImageMetadataPair], axis, tolerance=0) -> I
|
||||
return map(CoordGrouper(axis, tolerance=tolerance).group_pairs_by_greater_coordinate, groups)
|
||||
|
||||
def merge_groups_along_orthogonal_axis(groups):
|
||||
return map(make_group_merger(axis), groups)
|
||||
return map(rpartial(make_group_merger(axis), tolerance), groups)
|
||||
|
||||
def group_pairs_by_lesser_coordinate(pairs):
|
||||
return CoordGrouper(axis, tolerance=tolerance).group_pairs_by_lesser_coordinate(pairs)
|
||||
@ -62,29 +62,33 @@ def make_group_merger(axis):
|
||||
return {"y": merge_group_vertically, "x": merge_group_horizontally}[axis]
|
||||
|
||||
|
||||
def merge_group_vertically(group: Iterable[ImageMetadataPair]):
|
||||
return merge_group(group, "y")
|
||||
def merge_group_vertically(group: Iterable[ImageMetadataPair], tolerance=0):
|
||||
return merge_group(group, "y", tolerance=tolerance)
|
||||
|
||||
|
||||
def merge_group_horizontally(group: Iterable[ImageMetadataPair]):
|
||||
return merge_group(group, "x")
|
||||
def merge_group_horizontally(group: Iterable[ImageMetadataPair], tolerance=0):
|
||||
return merge_group(group, "x", tolerance=tolerance)
|
||||
|
||||
|
||||
def merge_group(group: Iterable[ImageMetadataPair], direction):
|
||||
reduce_group = make_merger_aggregator(direction)
|
||||
def merge_group(group: Iterable[ImageMetadataPair], direction, tolerance=0):
|
||||
reduce_group = make_merger_aggregator(direction, tolerance=tolerance)
|
||||
return until(no_new_merges, reduce_group, group)
|
||||
|
||||
|
||||
def make_merger_aggregator(axis) -> Callable[[Iterable[ImageMetadataPair]], Iterable[ImageMetadataPair]]:
|
||||
def make_merger_aggregator(axis, tolerance=0) -> Callable[[Iterable[ImageMetadataPair]], Iterable[ImageMetadataPair]]:
|
||||
"""Produces a function f : [H, T1, ... Tn] -> [HTi...Tj, Tk ... Tl] that merges adjacent image-metadata pairs on the
|
||||
head H and aggregates non-adjacent in the tail T.
|
||||
|
||||
Note:
|
||||
When tolerance > 0, the bounding box of the merged image no longer matches the bounding box of the mereged
|
||||
metadata. This is intended behaviour, but might be not be expected by the caller.
|
||||
"""
|
||||
|
||||
def merger_aggregator(pairs: Iterable[ImageMetadataPair]):
|
||||
def merge_on_head_and_aggregate_in_tail(pairs_aggr: Iterable[ImageMetadataPair], pair: ImageMetadataPair):
|
||||
"""Keeps the image that is being merged with as the head and aggregates non-mergables in the tail."""
|
||||
aggr, non_aggr = juxt(first, rest)(pairs_aggr)
|
||||
if c2_getter(aggr) == c1_getter(pair):
|
||||
if abs(c2_getter(aggr) - c1_getter(pair)) <= tolerance:
|
||||
aggr = pair_merger(aggr, pair)
|
||||
return aggr, *non_aggr
|
||||
else:
|
||||
@ -96,6 +100,8 @@ def make_merger_aggregator(axis) -> Callable[[Iterable[ImageMetadataPair]], Iter
|
||||
head_pair, pairs = juxt(first, rest)(pairs)
|
||||
return list(reduce(merge_on_head_and_aggregate_in_tail, pairs, [head_pair]))
|
||||
|
||||
assert tolerance >= 0
|
||||
|
||||
c1_getter = make_coord_getter(f"{axis}1")
|
||||
c2_getter = make_coord_getter(f"{axis}2")
|
||||
pair_merger = make_pair_merger(axis)
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
from typing import Iterable
|
||||
from typing import Iterable, List
|
||||
|
||||
from funcy import rpartial
|
||||
|
||||
@ -7,7 +7,7 @@ from image_prediction.stitching.merging import merge_along_both_axes, no_new_mer
|
||||
from image_prediction.utils.generic import until
|
||||
|
||||
|
||||
def stitch_pairs(pairs: Iterable[ImageMetadataPair], tolerance=0) -> Iterable[ImageMetadataPair]:
|
||||
def stitch_pairs(pairs: Iterable[ImageMetadataPair], tolerance=0) -> List[ImageMetadataPair]:
|
||||
"""Given a collection of image-metadata pairs from the same pages, combines all pairs that constitute adjacent
|
||||
images."""
|
||||
return until(no_new_merges, rpartial(merge_along_both_axes, tolerance), pairs)
|
||||
|
||||
@ -12,7 +12,7 @@ import fpdf
|
||||
import numpy as np
|
||||
import pytest
|
||||
from PIL import Image
|
||||
from funcy import rcompose
|
||||
from funcy import rcompose, merge
|
||||
|
||||
from image_prediction.classifier.classifier import Classifier
|
||||
from image_prediction.classifier.image_classifier import ImageClassifier
|
||||
@ -460,6 +460,13 @@ def get_base_position_metadata(width, height, page_width, page_height):
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def base_patch_metadata(width, height, page_width, page_height):
|
||||
metadata = get_base_position_metadata(width, height, page_width, page_height)
|
||||
metadata = merge(metadata, {Info.X1: 0, Info.Y1: 0, Info.X2: width, Info.Y2: height})
|
||||
return metadata
|
||||
|
||||
|
||||
@pytest.fixture(params=[33, 100])
|
||||
def height(request):
|
||||
return request.param
|
||||
|
||||
@ -1,4 +1,5 @@
|
||||
from copy import deepcopy
|
||||
from copy import deepcopy
|
||||
from functools import partial
|
||||
from itertools import starmap, repeat
|
||||
from operator import itemgetter
|
||||
@ -9,22 +10,28 @@ import numpy as np
|
||||
import pdf2image
|
||||
import pytest
|
||||
from PIL import Image
|
||||
from funcy import merge, juxt, one, first
|
||||
from funcy import juxt, one, first
|
||||
|
||||
from image_prediction.estimator.preprocessor.utils import image_to_normalized_tensor
|
||||
from image_prediction.image_extractor.extractor import ImageMetadataPair
|
||||
from image_prediction.info import Info
|
||||
from image_prediction.stitching.grouping import group_by_coordinate
|
||||
from image_prediction.stitching.merging import (
|
||||
merge_metadata_horizontally,
|
||||
merge_metadata_vertically,
|
||||
merge_pair_horizontally,
|
||||
merge_pair_vertically,
|
||||
concat_images_horizontally,
|
||||
concat_images_vertically,
|
||||
merge_group_horizontally,
|
||||
merge_group_vertically,
|
||||
)
|
||||
from image_prediction.stitching.stitching import stitch_pairs
|
||||
from image_prediction.stitching.utils import (
|
||||
make_coord_getter,
|
||||
make_length_getter,
|
||||
)
|
||||
from image_prediction.stitching.merging import merge_metadata_horizontally, merge_metadata_vertically, \
|
||||
merge_pair_horizontally, merge_pair_vertically, concat_images_horizontally, concat_images_vertically, \
|
||||
merge_group_horizontally, merge_group_vertically
|
||||
from test.conftest import (
|
||||
get_base_position_metadata,
|
||||
add_image,
|
||||
random_single_color_image_from_metadata,
|
||||
random_size_gray_image_from_metadata,
|
||||
@ -35,19 +42,6 @@ x1_getter, y1_getter, x2_getter, y2_getter = map(make_coord_getter, ("x1", "y1",
|
||||
width_getter, height_getter = map(make_length_getter, ("width", "height"))
|
||||
|
||||
|
||||
# @pytest.mark.parametrize("noise", [(0, 3)])
|
||||
# @pytest.mark.parametrize("split_count", [2])
|
||||
# def test_image_stitcher_with_gaps(patch_image_metadata_pairs, base_patch_metadata, base_patch_image):
|
||||
# pair_stitched = first(stitch_pairs(patch_image_metadata_pairs, tolerance=12))
|
||||
# pair_stitched.image.show()
|
||||
# # base_patch_image.show()
|
||||
# input()
|
||||
# import IPython
|
||||
# IPython.embed()
|
||||
# assert pair_stitched.metadata == base_patch_metadata
|
||||
# assert images_equal(pair_stitched.image.resize((10, 10)), base_patch_image.resize((10, 10)), atol=0.4)
|
||||
|
||||
|
||||
def test_group_by_coordinate_exact():
|
||||
pairs = [(0, 1), (0, 3), (1, 4), (1, 4), (1, 2), (3, 3)]
|
||||
pairs_grouped = list(group_by_coordinate(pairs, itemgetter(0), tolerance=0))
|
||||
@ -61,11 +55,35 @@ def test_group_by_coordinate_fuzzy():
|
||||
|
||||
|
||||
def test_image_stitcher(patch_image_metadata_pairs, base_patch_metadata, base_patch_image):
|
||||
pair_stitched = first(stitch_pairs(patch_image_metadata_pairs))
|
||||
pairs_stitched = stitch_pairs(patch_image_metadata_pairs)
|
||||
pair_stitched = first(pairs_stitched)
|
||||
|
||||
assert len(pairs_stitched) == 1
|
||||
assert pair_stitched.metadata == base_patch_metadata
|
||||
assert images_equal(pair_stitched.image.resize((10, 10)), base_patch_image.resize((10, 10)), atol=0.4)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("noise", [(0, 2)])
|
||||
@pytest.mark.parametrize("split_count", [5])
|
||||
@pytest.mark.parametrize("width", [100])
|
||||
@pytest.mark.parametrize("height", [100])
|
||||
@pytest.mark.parametrize("page_width", [100])
|
||||
@pytest.mark.parametrize("page_height", [100])
|
||||
@pytest.mark.parametrize("execution_number", range(100))
|
||||
def test_image_stitcher_with_gaps(patch_image_metadata_pairs, base_patch_metadata, base_patch_image, execution_number):
|
||||
print(len(patch_image_metadata_pairs))
|
||||
pairs_stitched = stitch_pairs(patch_image_metadata_pairs, tolerance=7)
|
||||
try:
|
||||
assert len(pairs_stitched) == 1
|
||||
except:
|
||||
for p in pairs_stitched:
|
||||
p.image.show()
|
||||
base_patch_image.show()
|
||||
import IPython
|
||||
|
||||
IPython.embed()
|
||||
|
||||
|
||||
def test_merge_group_horizontally(horizontal_merge_test_pairs):
|
||||
pr1, pr2, pr_merged_expected = horizontal_merge_test_pairs
|
||||
|
||||
@ -211,13 +229,6 @@ def patch_image_metadata_pairs(patches_metadata) -> List[ImageMetadataPair]:
|
||||
return list(starmap(ImageMetadataPair, zip(images, patches_metadata)))
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def base_patch_metadata(width, height, page_width, page_height):
|
||||
metadata = get_base_position_metadata(width, height, page_width, page_height)
|
||||
metadata = merge(metadata, {Info.X1: 0, Info.Y1: 0, Info.X2: width, Info.Y2: height})
|
||||
return metadata
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def patches_metadata(base_patch_metadata, noise, split_count):
|
||||
patches_metadata = list(BoxSplitter(noise).split_box(base_patch_metadata, split_count))
|
||||
|
||||
50
test/unit_tests/split_mapper_test.py
Normal file
50
test/unit_tests/split_mapper_test.py
Normal file
@ -0,0 +1,50 @@
|
||||
from image_prediction.info import Info
|
||||
from image_prediction.stitching.split_mapper import VerticalSplitMapper, HorizontalSplitMapper
|
||||
|
||||
|
||||
def test_split_vertical_mapper(base_patch_metadata):
|
||||
sm = VerticalSplitMapper(base_patch_metadata)
|
||||
|
||||
sm.c1 += 10 + 3
|
||||
sm.c2 += 20 + 3
|
||||
sm.dim += 20 + 3
|
||||
smw = sm.wrapped
|
||||
|
||||
assert smw[Info.Y1] == sm.c1 == base_patch_metadata[Info.Y1] + 10 + 3
|
||||
assert smw[Info.Y2] == sm.c2 == base_patch_metadata[Info.Y2] + 20 + 3
|
||||
assert smw[Info.HEIGHT] == sm.dim == base_patch_metadata[Info.HEIGHT] + 20 + 3
|
||||
|
||||
sm = VerticalSplitMapper(base_patch_metadata)
|
||||
|
||||
sm.c1 = 10 + 3
|
||||
sm.c2 = 20 + 3
|
||||
sm.dim = 20 + 3
|
||||
smw = sm.wrapped
|
||||
|
||||
assert smw[Info.Y1] == sm.c1 == 10 + 3
|
||||
assert smw[Info.Y2] == sm.c2 == 20 + 3
|
||||
assert smw[Info.HEIGHT] == sm.dim == 20 + 3
|
||||
|
||||
|
||||
def test_split_horizontal_mapper(base_patch_metadata):
|
||||
sm = HorizontalSplitMapper(base_patch_metadata)
|
||||
|
||||
sm.c1 += 10 + 3
|
||||
sm.c2 += 20 + 3
|
||||
sm.dim += 20 + 3
|
||||
smw = sm.wrapped
|
||||
|
||||
assert smw[Info.X1] == sm.c1 == base_patch_metadata[Info.X1] + 10 + 3
|
||||
assert smw[Info.X2] == sm.c2 == base_patch_metadata[Info.X2] + 20 + 3
|
||||
assert smw[Info.WIDTH] == sm.dim == base_patch_metadata[Info.WIDTH] + 20 + 3
|
||||
|
||||
sm = HorizontalSplitMapper(base_patch_metadata)
|
||||
|
||||
sm.c1 = 10 + 3
|
||||
sm.c2 = 20 + 3
|
||||
sm.dim = 20 + 3
|
||||
smw = sm.wrapped
|
||||
|
||||
assert smw[Info.X1] == sm.c1 == 10 + 3
|
||||
assert smw[Info.X2] == sm.c2 == 20 + 3
|
||||
assert smw[Info.WIDTH] == sm.dim == 20 + 3
|
||||
@ -11,8 +11,6 @@ class BoxSplitter:
|
||||
def __init__(self, noise=None):
|
||||
self.__steps = None
|
||||
self.__noise = (0, 0) if not noise else noise
|
||||
if not min(self.__noise) >= 0:
|
||||
raise ValueError("Noise interval must be non-negative.")
|
||||
|
||||
def split_box(self, box, steps=5):
|
||||
self.__steps = steps
|
||||
@ -55,7 +53,7 @@ class BoxSplitter:
|
||||
)
|
||||
|
||||
def noise(self):
|
||||
return int(random.uniform(*self.__noise))
|
||||
return int(round(random.uniform(*self.__noise)))
|
||||
|
||||
@staticmethod
|
||||
def __large_enough(wrapped_box: SplitMapper):
|
||||
@ -64,14 +62,15 @@ class BoxSplitter:
|
||||
def __get_child_boxes(self, wrapped_box: SplitMapper):
|
||||
|
||||
split_len = random.randint(5, wrapped_box.dim - 5)
|
||||
split_point = wrapped_box.c1 + split_len + self.noise()
|
||||
split_point = wrapped_box.c1 + split_len
|
||||
|
||||
box_left, box_right = juxt(deepcopy, deepcopy)(wrapped_box)
|
||||
|
||||
box_left.dim = split_len + self.noise()
|
||||
box_right.dim = wrapped_box.dim - split_len + self.noise()
|
||||
noise = - self.noise()
|
||||
box_left.dim = split_len + noise
|
||||
box_right.dim = wrapped_box.dim - split_len
|
||||
|
||||
box_left.c2 = split_point + self.noise()
|
||||
box_left.c2 = split_point + noise
|
||||
box_right.c1 = split_point + self.noise()
|
||||
|
||||
return box_left.wrapped, box_right.wrapped
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user