fuzzy stitching WIP: mostly works, but sometimes fails. run test_image_stitcher_with_gaps to debug
This commit is contained in:
parent
79cd31850d
commit
bb7c1be630
@ -29,7 +29,6 @@ def fuzzify(func, tolerance):
|
|||||||
nonlocal mid_points
|
nonlocal mid_points
|
||||||
nonlocal lower_bounds
|
nonlocal lower_bounds
|
||||||
nonlocal upper_bounds
|
nonlocal upper_bounds
|
||||||
print(tolerance)
|
|
||||||
|
|
||||||
value = func(item)
|
value = func(item)
|
||||||
fits = (array(lower_bounds_array()) <= value) & (value <= array(upper_bounds_array()))
|
fits = (array(lower_bounds_array()) <= value) & (value <= array(upper_bounds_array()))
|
||||||
|
|||||||
@ -3,7 +3,7 @@ from functools import reduce
|
|||||||
from typing import Iterable, Callable, List
|
from typing import Iterable, Callable, List
|
||||||
|
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
from funcy import juxt, first, rest, rcompose
|
from funcy import juxt, first, rest, rcompose, rpartial
|
||||||
|
|
||||||
from image_prediction.image_extractor.extractor import ImageMetadataPair
|
from image_prediction.image_extractor.extractor import ImageMetadataPair
|
||||||
from image_prediction.info import Info
|
from image_prediction.info import Info
|
||||||
@ -44,7 +44,7 @@ def merge_along_axis(pairs: Iterable[ImageMetadataPair], axis, tolerance=0) -> I
|
|||||||
return map(CoordGrouper(axis, tolerance=tolerance).group_pairs_by_greater_coordinate, groups)
|
return map(CoordGrouper(axis, tolerance=tolerance).group_pairs_by_greater_coordinate, groups)
|
||||||
|
|
||||||
def merge_groups_along_orthogonal_axis(groups):
|
def merge_groups_along_orthogonal_axis(groups):
|
||||||
return map(make_group_merger(axis), groups)
|
return map(rpartial(make_group_merger(axis), tolerance), groups)
|
||||||
|
|
||||||
def group_pairs_by_lesser_coordinate(pairs):
|
def group_pairs_by_lesser_coordinate(pairs):
|
||||||
return CoordGrouper(axis, tolerance=tolerance).group_pairs_by_lesser_coordinate(pairs)
|
return CoordGrouper(axis, tolerance=tolerance).group_pairs_by_lesser_coordinate(pairs)
|
||||||
@ -62,29 +62,33 @@ def make_group_merger(axis):
|
|||||||
return {"y": merge_group_vertically, "x": merge_group_horizontally}[axis]
|
return {"y": merge_group_vertically, "x": merge_group_horizontally}[axis]
|
||||||
|
|
||||||
|
|
||||||
def merge_group_vertically(group: Iterable[ImageMetadataPair]):
|
def merge_group_vertically(group: Iterable[ImageMetadataPair], tolerance=0):
|
||||||
return merge_group(group, "y")
|
return merge_group(group, "y", tolerance=tolerance)
|
||||||
|
|
||||||
|
|
||||||
def merge_group_horizontally(group: Iterable[ImageMetadataPair]):
|
def merge_group_horizontally(group: Iterable[ImageMetadataPair], tolerance=0):
|
||||||
return merge_group(group, "x")
|
return merge_group(group, "x", tolerance=tolerance)
|
||||||
|
|
||||||
|
|
||||||
def merge_group(group: Iterable[ImageMetadataPair], direction):
|
def merge_group(group: Iterable[ImageMetadataPair], direction, tolerance=0):
|
||||||
reduce_group = make_merger_aggregator(direction)
|
reduce_group = make_merger_aggregator(direction, tolerance=tolerance)
|
||||||
return until(no_new_merges, reduce_group, group)
|
return until(no_new_merges, reduce_group, group)
|
||||||
|
|
||||||
|
|
||||||
def make_merger_aggregator(axis) -> Callable[[Iterable[ImageMetadataPair]], Iterable[ImageMetadataPair]]:
|
def make_merger_aggregator(axis, tolerance=0) -> Callable[[Iterable[ImageMetadataPair]], Iterable[ImageMetadataPair]]:
|
||||||
"""Produces a function f : [H, T1, ... Tn] -> [HTi...Tj, Tk ... Tl] that merges adjacent image-metadata pairs on the
|
"""Produces a function f : [H, T1, ... Tn] -> [HTi...Tj, Tk ... Tl] that merges adjacent image-metadata pairs on the
|
||||||
head H and aggregates non-adjacent in the tail T.
|
head H and aggregates non-adjacent in the tail T.
|
||||||
|
|
||||||
|
Note:
|
||||||
|
When tolerance > 0, the bounding box of the merged image no longer matches the bounding box of the merged
|
||||||
|
metadata. This is intended behaviour, but might not be expected by the caller.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def merger_aggregator(pairs: Iterable[ImageMetadataPair]):
|
def merger_aggregator(pairs: Iterable[ImageMetadataPair]):
|
||||||
def merge_on_head_and_aggregate_in_tail(pairs_aggr: Iterable[ImageMetadataPair], pair: ImageMetadataPair):
|
def merge_on_head_and_aggregate_in_tail(pairs_aggr: Iterable[ImageMetadataPair], pair: ImageMetadataPair):
|
||||||
"""Keeps the image that is being merged with as the head and aggregates non-mergables in the tail."""
|
"""Keeps the image that is being merged with as the head and aggregates non-mergables in the tail."""
|
||||||
aggr, non_aggr = juxt(first, rest)(pairs_aggr)
|
aggr, non_aggr = juxt(first, rest)(pairs_aggr)
|
||||||
if c2_getter(aggr) == c1_getter(pair):
|
if abs(c2_getter(aggr) - c1_getter(pair)) <= tolerance:
|
||||||
aggr = pair_merger(aggr, pair)
|
aggr = pair_merger(aggr, pair)
|
||||||
return aggr, *non_aggr
|
return aggr, *non_aggr
|
||||||
else:
|
else:
|
||||||
@ -96,6 +100,8 @@ def make_merger_aggregator(axis) -> Callable[[Iterable[ImageMetadataPair]], Iter
|
|||||||
head_pair, pairs = juxt(first, rest)(pairs)
|
head_pair, pairs = juxt(first, rest)(pairs)
|
||||||
return list(reduce(merge_on_head_and_aggregate_in_tail, pairs, [head_pair]))
|
return list(reduce(merge_on_head_and_aggregate_in_tail, pairs, [head_pair]))
|
||||||
|
|
||||||
|
assert tolerance >= 0
|
||||||
|
|
||||||
c1_getter = make_coord_getter(f"{axis}1")
|
c1_getter = make_coord_getter(f"{axis}1")
|
||||||
c2_getter = make_coord_getter(f"{axis}2")
|
c2_getter = make_coord_getter(f"{axis}2")
|
||||||
pair_merger = make_pair_merger(axis)
|
pair_merger = make_pair_merger(axis)
|
||||||
|
|||||||
@ -1,4 +1,4 @@
|
|||||||
from typing import Iterable
|
from typing import Iterable, List
|
||||||
|
|
||||||
from funcy import rpartial
|
from funcy import rpartial
|
||||||
|
|
||||||
@ -7,7 +7,7 @@ from image_prediction.stitching.merging import merge_along_both_axes, no_new_mer
|
|||||||
from image_prediction.utils.generic import until
|
from image_prediction.utils.generic import until
|
||||||
|
|
||||||
|
|
||||||
def stitch_pairs(pairs: Iterable[ImageMetadataPair], tolerance=0) -> Iterable[ImageMetadataPair]:
|
def stitch_pairs(pairs: Iterable[ImageMetadataPair], tolerance=0) -> List[ImageMetadataPair]:
|
||||||
"""Given a collection of image-metadata pairs from the same pages, combines all pairs that constitute adjacent
|
"""Given a collection of image-metadata pairs from the same pages, combines all pairs that constitute adjacent
|
||||||
images."""
|
images."""
|
||||||
return until(no_new_merges, rpartial(merge_along_both_axes, tolerance), pairs)
|
return until(no_new_merges, rpartial(merge_along_both_axes, tolerance), pairs)
|
||||||
|
|||||||
@ -12,7 +12,7 @@ import fpdf
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import pytest
|
import pytest
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
from funcy import rcompose
|
from funcy import rcompose, merge
|
||||||
|
|
||||||
from image_prediction.classifier.classifier import Classifier
|
from image_prediction.classifier.classifier import Classifier
|
||||||
from image_prediction.classifier.image_classifier import ImageClassifier
|
from image_prediction.classifier.image_classifier import ImageClassifier
|
||||||
@ -460,6 +460,13 @@ def get_base_position_metadata(width, height, page_width, page_height):
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def base_patch_metadata(width, height, page_width, page_height):
    """Metadata of the un-split base patch, with a box from (0, 0) to (width, height).

    Takes the base position metadata for the given patch and page sizes and
    merges the four corner coordinates into it.
    """
    position_metadata = get_base_position_metadata(width, height, page_width, page_height)
    corner_coordinates = {Info.X1: 0, Info.Y1: 0, Info.X2: width, Info.Y2: height}
    return merge(position_metadata, corner_coordinates)
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(params=[33, 100])
|
@pytest.fixture(params=[33, 100])
|
||||||
def height(request):
|
def height(request):
|
||||||
return request.param
|
return request.param
|
||||||
|
|||||||
@ -1,4 +1,5 @@
|
|||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
|
from copy import deepcopy
|
||||||
from functools import partial
|
from functools import partial
|
||||||
from itertools import starmap, repeat
|
from itertools import starmap, repeat
|
||||||
from operator import itemgetter
|
from operator import itemgetter
|
||||||
@ -9,22 +10,28 @@ import numpy as np
|
|||||||
import pdf2image
|
import pdf2image
|
||||||
import pytest
|
import pytest
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
from funcy import merge, juxt, one, first
|
from funcy import juxt, one, first
|
||||||
|
|
||||||
from image_prediction.estimator.preprocessor.utils import image_to_normalized_tensor
|
from image_prediction.estimator.preprocessor.utils import image_to_normalized_tensor
|
||||||
from image_prediction.image_extractor.extractor import ImageMetadataPair
|
from image_prediction.image_extractor.extractor import ImageMetadataPair
|
||||||
from image_prediction.info import Info
|
from image_prediction.info import Info
|
||||||
from image_prediction.stitching.grouping import group_by_coordinate
|
from image_prediction.stitching.grouping import group_by_coordinate
|
||||||
|
from image_prediction.stitching.merging import (
|
||||||
|
merge_metadata_horizontally,
|
||||||
|
merge_metadata_vertically,
|
||||||
|
merge_pair_horizontally,
|
||||||
|
merge_pair_vertically,
|
||||||
|
concat_images_horizontally,
|
||||||
|
concat_images_vertically,
|
||||||
|
merge_group_horizontally,
|
||||||
|
merge_group_vertically,
|
||||||
|
)
|
||||||
from image_prediction.stitching.stitching import stitch_pairs
|
from image_prediction.stitching.stitching import stitch_pairs
|
||||||
from image_prediction.stitching.utils import (
|
from image_prediction.stitching.utils import (
|
||||||
make_coord_getter,
|
make_coord_getter,
|
||||||
make_length_getter,
|
make_length_getter,
|
||||||
)
|
)
|
||||||
from image_prediction.stitching.merging import merge_metadata_horizontally, merge_metadata_vertically, \
|
|
||||||
merge_pair_horizontally, merge_pair_vertically, concat_images_horizontally, concat_images_vertically, \
|
|
||||||
merge_group_horizontally, merge_group_vertically
|
|
||||||
from test.conftest import (
|
from test.conftest import (
|
||||||
get_base_position_metadata,
|
|
||||||
add_image,
|
add_image,
|
||||||
random_single_color_image_from_metadata,
|
random_single_color_image_from_metadata,
|
||||||
random_size_gray_image_from_metadata,
|
random_size_gray_image_from_metadata,
|
||||||
@ -35,19 +42,6 @@ x1_getter, y1_getter, x2_getter, y2_getter = map(make_coord_getter, ("x1", "y1",
|
|||||||
width_getter, height_getter = map(make_length_getter, ("width", "height"))
|
width_getter, height_getter = map(make_length_getter, ("width", "height"))
|
||||||
|
|
||||||
|
|
||||||
# @pytest.mark.parametrize("noise", [(0, 3)])
|
|
||||||
# @pytest.mark.parametrize("split_count", [2])
|
|
||||||
# def test_image_stitcher_with_gaps(patch_image_metadata_pairs, base_patch_metadata, base_patch_image):
|
|
||||||
# pair_stitched = first(stitch_pairs(patch_image_metadata_pairs, tolerance=12))
|
|
||||||
# pair_stitched.image.show()
|
|
||||||
# # base_patch_image.show()
|
|
||||||
# input()
|
|
||||||
# import IPython
|
|
||||||
# IPython.embed()
|
|
||||||
# assert pair_stitched.metadata == base_patch_metadata
|
|
||||||
# assert images_equal(pair_stitched.image.resize((10, 10)), base_patch_image.resize((10, 10)), atol=0.4)
|
|
||||||
|
|
||||||
|
|
||||||
def test_group_by_coordinate_exact():
|
def test_group_by_coordinate_exact():
|
||||||
pairs = [(0, 1), (0, 3), (1, 4), (1, 4), (1, 2), (3, 3)]
|
pairs = [(0, 1), (0, 3), (1, 4), (1, 4), (1, 2), (3, 3)]
|
||||||
pairs_grouped = list(group_by_coordinate(pairs, itemgetter(0), tolerance=0))
|
pairs_grouped = list(group_by_coordinate(pairs, itemgetter(0), tolerance=0))
|
||||||
@ -61,11 +55,35 @@ def test_group_by_coordinate_fuzzy():
|
|||||||
|
|
||||||
|
|
||||||
def test_image_stitcher(patch_image_metadata_pairs, base_patch_metadata, base_patch_image):
|
def test_image_stitcher(patch_image_metadata_pairs, base_patch_metadata, base_patch_image):
|
||||||
pair_stitched = first(stitch_pairs(patch_image_metadata_pairs))
|
pairs_stitched = stitch_pairs(patch_image_metadata_pairs)
|
||||||
|
pair_stitched = first(pairs_stitched)
|
||||||
|
|
||||||
|
assert len(pairs_stitched) == 1
|
||||||
assert pair_stitched.metadata == base_patch_metadata
|
assert pair_stitched.metadata == base_patch_metadata
|
||||||
assert images_equal(pair_stitched.image.resize((10, 10)), base_patch_image.resize((10, 10)), atol=0.4)
|
assert images_equal(pair_stitched.image.resize((10, 10)), base_patch_image.resize((10, 10)), atol=0.4)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("noise", [(0, 2)])
@pytest.mark.parametrize("split_count", [5])
@pytest.mark.parametrize("width", [100])
@pytest.mark.parametrize("height", [100])
@pytest.mark.parametrize("page_width", [100])
@pytest.mark.parametrize("page_height", [100])
@pytest.mark.parametrize("execution_number", range(100))
def test_image_stitcher_with_gaps(patch_image_metadata_pairs, base_patch_metadata, base_patch_image, execution_number):
    """Patches with noisy borders must be stitched into exactly one pair when a tolerance is given.

    ``execution_number`` is not used in the body; parametrizing over it repeats the
    test 100 times. ``base_patch_metadata`` and ``base_patch_image`` are kept in the
    signature so the requested fixtures stay the same, but are no longer read here.

    To inspect a failing run interactively, use ``pytest --pdb`` instead of embedding
    a shell in the test.
    """
    pairs_stitched = stitch_pairs(patch_image_metadata_pairs, tolerance=7)

    # Deliberately a plain assertion: wrapping it in a bare ``except`` swallowed the
    # AssertionError, so the test passed even when stitching left several pairs.
    assert len(pairs_stitched) == 1, (
        f"expected 1 stitched pair from {len(patch_image_metadata_pairs)} patches, "
        f"got {len(pairs_stitched)}"
    )
|
||||||
|
|
||||||
|
|
||||||
def test_merge_group_horizontally(horizontal_merge_test_pairs):
|
def test_merge_group_horizontally(horizontal_merge_test_pairs):
|
||||||
pr1, pr2, pr_merged_expected = horizontal_merge_test_pairs
|
pr1, pr2, pr_merged_expected = horizontal_merge_test_pairs
|
||||||
|
|
||||||
@ -211,13 +229,6 @@ def patch_image_metadata_pairs(patches_metadata) -> List[ImageMetadataPair]:
|
|||||||
return list(starmap(ImageMetadataPair, zip(images, patches_metadata)))
|
return list(starmap(ImageMetadataPair, zip(images, patches_metadata)))
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def base_patch_metadata(width, height, page_width, page_height):
|
|
||||||
metadata = get_base_position_metadata(width, height, page_width, page_height)
|
|
||||||
metadata = merge(metadata, {Info.X1: 0, Info.Y1: 0, Info.X2: width, Info.Y2: height})
|
|
||||||
return metadata
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def patches_metadata(base_patch_metadata, noise, split_count):
|
def patches_metadata(base_patch_metadata, noise, split_count):
|
||||||
patches_metadata = list(BoxSplitter(noise).split_box(base_patch_metadata, split_count))
|
patches_metadata = list(BoxSplitter(noise).split_box(base_patch_metadata, split_count))
|
||||||
|
|||||||
50
test/unit_tests/split_mapper_test.py
Normal file
50
test/unit_tests/split_mapper_test.py
Normal file
@ -0,0 +1,50 @@
|
|||||||
|
from image_prediction.info import Info
|
||||||
|
from image_prediction.stitching.split_mapper import VerticalSplitMapper, HorizontalSplitMapper
|
||||||
|
|
||||||
|
|
||||||
|
def test_split_vertical_mapper(base_patch_metadata):
    """VerticalSplitMapper's ``c1``, ``c2`` and ``dim`` must agree with Y1, Y2 and HEIGHT of ``wrapped``.

    The first half updates the attributes with in-place increments, the second half
    with plain assignments on a fresh mapper.
    """
    lower_delta = 10 + 3
    upper_delta = 20 + 3

    # Increments relative to the values taken from the base metadata.
    mapper = VerticalSplitMapper(base_patch_metadata)
    mapper.c1 += lower_delta
    mapper.c2 += upper_delta
    mapper.dim += upper_delta
    wrapped = mapper.wrapped

    assert wrapped[Info.Y1] == mapper.c1 == base_patch_metadata[Info.Y1] + lower_delta
    assert wrapped[Info.Y2] == mapper.c2 == base_patch_metadata[Info.Y2] + upper_delta
    assert wrapped[Info.HEIGHT] == mapper.dim == base_patch_metadata[Info.HEIGHT] + upper_delta

    # Absolute assignments on a fresh mapper.
    mapper = VerticalSplitMapper(base_patch_metadata)
    mapper.c1 = lower_delta
    mapper.c2 = upper_delta
    mapper.dim = upper_delta
    wrapped = mapper.wrapped

    assert wrapped[Info.Y1] == mapper.c1 == lower_delta
    assert wrapped[Info.Y2] == mapper.c2 == upper_delta
    assert wrapped[Info.HEIGHT] == mapper.dim == upper_delta
|
||||||
|
|
||||||
|
|
||||||
|
def test_split_horizontal_mapper(base_patch_metadata):
    """HorizontalSplitMapper's ``c1``, ``c2`` and ``dim`` must agree with X1, X2 and WIDTH of ``wrapped``.

    The first half updates the attributes with in-place increments, the second half
    with plain assignments on a fresh mapper.
    """
    lower_delta = 10 + 3
    upper_delta = 20 + 3

    # Increments relative to the values taken from the base metadata.
    mapper = HorizontalSplitMapper(base_patch_metadata)
    mapper.c1 += lower_delta
    mapper.c2 += upper_delta
    mapper.dim += upper_delta
    wrapped = mapper.wrapped

    assert wrapped[Info.X1] == mapper.c1 == base_patch_metadata[Info.X1] + lower_delta
    assert wrapped[Info.X2] == mapper.c2 == base_patch_metadata[Info.X2] + upper_delta
    assert wrapped[Info.WIDTH] == mapper.dim == base_patch_metadata[Info.WIDTH] + upper_delta

    # Absolute assignments on a fresh mapper.
    mapper = HorizontalSplitMapper(base_patch_metadata)
    mapper.c1 = lower_delta
    mapper.c2 = upper_delta
    mapper.dim = upper_delta
    wrapped = mapper.wrapped

    assert wrapped[Info.X1] == mapper.c1 == lower_delta
    assert wrapped[Info.X2] == mapper.c2 == upper_delta
    assert wrapped[Info.WIDTH] == mapper.dim == upper_delta
|
||||||
@ -11,8 +11,6 @@ class BoxSplitter:
|
|||||||
def __init__(self, noise=None):
|
def __init__(self, noise=None):
|
||||||
self.__steps = None
|
self.__steps = None
|
||||||
self.__noise = (0, 0) if not noise else noise
|
self.__noise = (0, 0) if not noise else noise
|
||||||
if not min(self.__noise) >= 0:
|
|
||||||
raise ValueError("Noise interval must be non-negative.")
|
|
||||||
|
|
||||||
def split_box(self, box, steps=5):
|
def split_box(self, box, steps=5):
|
||||||
self.__steps = steps
|
self.__steps = steps
|
||||||
@ -55,7 +53,7 @@ class BoxSplitter:
|
|||||||
)
|
)
|
||||||
|
|
||||||
def noise(self):
|
def noise(self):
|
||||||
return int(random.uniform(*self.__noise))
|
return int(round(random.uniform(*self.__noise)))
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def __large_enough(wrapped_box: SplitMapper):
|
def __large_enough(wrapped_box: SplitMapper):
|
||||||
@ -64,14 +62,15 @@ class BoxSplitter:
|
|||||||
def __get_child_boxes(self, wrapped_box: SplitMapper):
|
def __get_child_boxes(self, wrapped_box: SplitMapper):
|
||||||
|
|
||||||
split_len = random.randint(5, wrapped_box.dim - 5)
|
split_len = random.randint(5, wrapped_box.dim - 5)
|
||||||
split_point = wrapped_box.c1 + split_len + self.noise()
|
split_point = wrapped_box.c1 + split_len
|
||||||
|
|
||||||
box_left, box_right = juxt(deepcopy, deepcopy)(wrapped_box)
|
box_left, box_right = juxt(deepcopy, deepcopy)(wrapped_box)
|
||||||
|
|
||||||
box_left.dim = split_len + self.noise()
|
noise = - self.noise()
|
||||||
box_right.dim = wrapped_box.dim - split_len + self.noise()
|
box_left.dim = split_len + noise
|
||||||
|
box_right.dim = wrapped_box.dim - split_len
|
||||||
|
|
||||||
box_left.c2 = split_point + self.noise()
|
box_left.c2 = split_point + noise
|
||||||
box_right.c1 = split_point + self.noise()
|
box_right.c1 = split_point + self.noise()
|
||||||
|
|
||||||
return box_left.wrapped, box_right.wrapped
|
return box_left.wrapped, box_right.wrapped
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user