fuzzy stitching completed
This commit is contained in:
parent
bb7c1be630
commit
d8f86d14a5
@ -9,3 +9,15 @@ class EnumFormatter(KeyFormatter):
|
|||||||
|
|
||||||
def transform(self, obj):
|
def transform(self, obj):
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
|
class ReverseEnumFormatter(KeyFormatter):
|
||||||
|
def __init__(self, enum):
|
||||||
|
self.enum = enum
|
||||||
|
self.reverse_enum = {e.value: e for e in enum}
|
||||||
|
|
||||||
|
def format_key(self, key):
|
||||||
|
return self.reverse_enum.get(key, key)
|
||||||
|
|
||||||
|
def transform(self, obj):
|
||||||
|
raise NotImplementedError
|
||||||
|
|||||||
@ -8,9 +8,9 @@ from funcy import juxt, first, rest, rcompose, rpartial
|
|||||||
from image_prediction.image_extractor.extractor import ImageMetadataPair
|
from image_prediction.image_extractor.extractor import ImageMetadataPair
|
||||||
from image_prediction.info import Info
|
from image_prediction.info import Info
|
||||||
from image_prediction.stitching.grouping import CoordGrouper
|
from image_prediction.stitching.grouping import CoordGrouper
|
||||||
from image_prediction.stitching.utils import make_coord_getter, flatten_groups_once
|
|
||||||
from image_prediction.utils.generic import until
|
|
||||||
from image_prediction.stitching.split_mapper import HorizontalSplitMapper, VerticalSplitMapper
|
from image_prediction.stitching.split_mapper import HorizontalSplitMapper, VerticalSplitMapper
|
||||||
|
from image_prediction.stitching.utils import make_coord_getter, flatten_groups_once, validate_box
|
||||||
|
from image_prediction.utils.generic import until
|
||||||
|
|
||||||
|
|
||||||
def no_new_merges(pairs1, pairs2):
|
def no_new_merges(pairs1, pairs2):
|
||||||
@ -139,13 +139,15 @@ def merge_metadata(m1: dict, m2: dict):
|
|||||||
|
|
||||||
c1 = min(m1.c1, m2.c1)
|
c1 = min(m1.c1, m2.c1)
|
||||||
c2 = max(m1.c2, m2.c2)
|
c2 = max(m1.c2, m2.c2)
|
||||||
dim = m1.dim + m2.dim
|
dim = abs(c2 - c1)
|
||||||
|
|
||||||
merged = deepcopy(m1)
|
merged = deepcopy(m1)
|
||||||
merged.dim = dim
|
merged.dim = dim
|
||||||
merged.c1 = c1
|
merged.c1 = c1
|
||||||
merged.c2 = c2
|
merged.c2 = c2
|
||||||
|
|
||||||
|
validate_box(merged.wrapped)
|
||||||
|
|
||||||
return merged.wrapped
|
return merged.wrapped
|
||||||
|
|
||||||
|
|
||||||
@ -163,7 +165,7 @@ def concat_images(im1: Image, im2: Image, metadata: dict, axis):
|
|||||||
|
|
||||||
images = [im1, im2]
|
images = [im1, im2]
|
||||||
|
|
||||||
offsets = [0, *[im.size[axis] for im in images]]
|
offsets = 0, im1.size[axis], im_aggr.size[axis] - im2.size[axis]
|
||||||
|
|
||||||
for im, offset in zip(images, offsets):
|
for im, offset in zip(images, offsets):
|
||||||
box = (offset, 0) if not axis else (0, offset)
|
box = (offset, 0) if not axis else (0, offset)
|
||||||
|
|||||||
@ -28,3 +28,8 @@ def make_length_getter(dim):
|
|||||||
"width": make_getter(Info.WIDTH),
|
"width": make_getter(Info.WIDTH),
|
||||||
"height": make_getter(Info.HEIGHT),
|
"height": make_getter(Info.HEIGHT),
|
||||||
}[dim]
|
}[dim]
|
||||||
|
|
||||||
|
|
||||||
|
def validate_box(box):
|
||||||
|
assert box[Info.X2] - box[Info.X1] == box[Info.WIDTH]
|
||||||
|
assert box[Info.Y2] - box[Info.Y1] == box[Info.HEIGHT]
|
||||||
@ -494,6 +494,7 @@ def random_single_color_image_from_metadata(metadata):
|
|||||||
return image
|
return image
|
||||||
|
|
||||||
|
|
||||||
|
# TODO: rename: not random!
|
||||||
def random_size_gray_image_from_metadata(metadata):
|
def random_size_gray_image_from_metadata(metadata):
|
||||||
image = Image.new("RGB", (metadata[Info.WIDTH], metadata[Info.HEIGHT]), color=(100, 100, 100))
|
image = Image.new("RGB", (metadata[Info.WIDTH], metadata[Info.HEIGHT]), color=(100, 100, 100))
|
||||||
return image
|
return image
|
||||||
|
|||||||
92
test/data/stitching_with_tolerance.json
Normal file
92
test/data/stitching_with_tolerance.json
Normal file
@ -0,0 +1,92 @@
|
|||||||
|
{
|
||||||
|
"input": [
|
||||||
|
{
|
||||||
|
"width": 100,
|
||||||
|
"height": 8,
|
||||||
|
"page_idx": 0,
|
||||||
|
"page_width": 100,
|
||||||
|
"page_height": 100,
|
||||||
|
"x1": 0,
|
||||||
|
"y1": 0,
|
||||||
|
"x2": 100,
|
||||||
|
"y2": 8
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"width": 100,
|
||||||
|
"height": 9,
|
||||||
|
"page_idx": 0,
|
||||||
|
"page_width": 100,
|
||||||
|
"page_height": 100,
|
||||||
|
"x1": 0,
|
||||||
|
"y1": 9,
|
||||||
|
"x2": 100,
|
||||||
|
"y2": 18
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"width": 100,
|
||||||
|
"height": 35,
|
||||||
|
"page_idx": 0,
|
||||||
|
"page_width": 100,
|
||||||
|
"page_height": 100,
|
||||||
|
"x1": 0,
|
||||||
|
"y1": 18,
|
||||||
|
"x2": 100,
|
||||||
|
"y2": 53
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"width": 47,
|
||||||
|
"height": 46,
|
||||||
|
"page_idx": 0,
|
||||||
|
"page_width": 100,
|
||||||
|
"page_height": 100,
|
||||||
|
"x1": 0,
|
||||||
|
"y1": 54,
|
||||||
|
"x2": 47,
|
||||||
|
"y2": 100
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"width": 31,
|
||||||
|
"height": 46,
|
||||||
|
"page_idx": 0,
|
||||||
|
"page_width": 100,
|
||||||
|
"page_height": 100,
|
||||||
|
"x1": 48,
|
||||||
|
"y1": 54,
|
||||||
|
"x2": 79,
|
||||||
|
"y2": 100
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"width": 20,
|
||||||
|
"height": 19,
|
||||||
|
"page_idx": 0,
|
||||||
|
"page_width": 100,
|
||||||
|
"page_height": 100,
|
||||||
|
"x1": 80,
|
||||||
|
"y1": 54,
|
||||||
|
"x2": 100,
|
||||||
|
"y2": 73
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"width": 20,
|
||||||
|
"height": 27,
|
||||||
|
"page_idx": 0,
|
||||||
|
"page_width": 100,
|
||||||
|
"page_height": 100,
|
||||||
|
"x1": 80,
|
||||||
|
"y1": 73,
|
||||||
|
"x2": 100,
|
||||||
|
"y2": 100
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"target": {
|
||||||
|
"width": 100,
|
||||||
|
"height": 100,
|
||||||
|
"page_idx": 0,
|
||||||
|
"page_width": 100,
|
||||||
|
"page_height": 100,
|
||||||
|
"x1": 0,
|
||||||
|
"y1": 0,
|
||||||
|
"x2": 100,
|
||||||
|
"y2": 100
|
||||||
|
}
|
||||||
|
}
|
||||||
@ -1,3 +1,5 @@
|
|||||||
|
import json
|
||||||
|
import os
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
from functools import partial
|
from functools import partial
|
||||||
@ -13,6 +15,7 @@ from PIL import Image
|
|||||||
from funcy import juxt, one, first
|
from funcy import juxt, one, first
|
||||||
|
|
||||||
from image_prediction.estimator.preprocessor.utils import image_to_normalized_tensor
|
from image_prediction.estimator.preprocessor.utils import image_to_normalized_tensor
|
||||||
|
from image_prediction.formatter.formatters.enum import ReverseEnumFormatter
|
||||||
from image_prediction.image_extractor.extractor import ImageMetadataPair
|
from image_prediction.image_extractor.extractor import ImageMetadataPair
|
||||||
from image_prediction.info import Info
|
from image_prediction.info import Info
|
||||||
from image_prediction.stitching.grouping import group_by_coordinate
|
from image_prediction.stitching.grouping import group_by_coordinate
|
||||||
@ -63,6 +66,22 @@ def test_image_stitcher(patch_image_metadata_pairs, base_patch_metadata, base_pa
|
|||||||
assert images_equal(pair_stitched.image.resize((10, 10)), base_patch_image.resize((10, 10)), atol=0.4)
|
assert images_equal(pair_stitched.image.resize((10, 10)), base_patch_image.resize((10, 10)), atol=0.4)
|
||||||
|
|
||||||
|
|
||||||
|
def test_image_stitcher_with_gaps_must_succeed():
|
||||||
|
from image_prediction.locations import TEST_DATA_DIR
|
||||||
|
|
||||||
|
with open(os.path.join(TEST_DATA_DIR, "stitching_with_tolerance.json")) as f:
|
||||||
|
patches_metadata, base_patch_metadata = itemgetter("input", "target")(ReverseEnumFormatter(Info)(json.load(f)))
|
||||||
|
|
||||||
|
images = map(random_size_gray_image_from_metadata, patches_metadata)
|
||||||
|
patch_image_metadata_pairs = list(starmap(ImageMetadataPair, zip(images, patches_metadata)))
|
||||||
|
|
||||||
|
pairs_stitched = stitch_pairs(patch_image_metadata_pairs, tolerance=7)
|
||||||
|
|
||||||
|
assert len(pairs_stitched) == 1
|
||||||
|
pair_stitched = first(pairs_stitched)
|
||||||
|
assert pair_stitched.metadata == base_patch_metadata
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("noise", [(0, 2)])
|
@pytest.mark.parametrize("noise", [(0, 2)])
|
||||||
@pytest.mark.parametrize("split_count", [5])
|
@pytest.mark.parametrize("split_count", [5])
|
||||||
@pytest.mark.parametrize("width", [100])
|
@pytest.mark.parametrize("width", [100])
|
||||||
@ -70,18 +89,10 @@ def test_image_stitcher(patch_image_metadata_pairs, base_patch_metadata, base_pa
|
|||||||
@pytest.mark.parametrize("page_width", [100])
|
@pytest.mark.parametrize("page_width", [100])
|
||||||
@pytest.mark.parametrize("page_height", [100])
|
@pytest.mark.parametrize("page_height", [100])
|
||||||
@pytest.mark.parametrize("execution_number", range(100))
|
@pytest.mark.parametrize("execution_number", range(100))
|
||||||
def test_image_stitcher_with_gaps(patch_image_metadata_pairs, base_patch_metadata, base_patch_image, execution_number):
|
@pytest.mark.xfail(reason="Does not always succeed due to locally maximizing merging logic.")
|
||||||
print(len(patch_image_metadata_pairs))
|
def test_image_stitcher_with_gaps_can_fail(patch_image_metadata_pairs, base_patch_metadata, execution_number):
|
||||||
pairs_stitched = stitch_pairs(patch_image_metadata_pairs, tolerance=7)
|
pairs_stitched = stitch_pairs(patch_image_metadata_pairs, tolerance=4)
|
||||||
try:
|
assert len(pairs_stitched) == 1 and first(pairs_stitched).metadata == base_patch_metadata
|
||||||
assert len(pairs_stitched) == 1
|
|
||||||
except:
|
|
||||||
for p in pairs_stitched:
|
|
||||||
p.image.show()
|
|
||||||
base_patch_image.show()
|
|
||||||
import IPython
|
|
||||||
|
|
||||||
IPython.embed()
|
|
||||||
|
|
||||||
|
|
||||||
def test_merge_group_horizontally(horizontal_merge_test_pairs):
|
def test_merge_group_horizontally(horizontal_merge_test_pairs):
|
||||||
|
|||||||
@ -2,9 +2,10 @@ import random
|
|||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
from itertools import chain
|
from itertools import chain
|
||||||
|
|
||||||
from funcy import rpartial, juxt, first
|
from funcy import rpartial, juxt
|
||||||
|
|
||||||
from image_prediction.stitching.split_mapper import SplitMapper, HorizontalSplitMapper, VerticalSplitMapper
|
from image_prediction.stitching.split_mapper import SplitMapper, HorizontalSplitMapper, VerticalSplitMapper
|
||||||
|
from image_prediction.stitching.utils import validate_box
|
||||||
|
|
||||||
|
|
||||||
class BoxSplitter:
|
class BoxSplitter:
|
||||||
@ -71,6 +72,9 @@ class BoxSplitter:
|
|||||||
box_right.dim = wrapped_box.dim - split_len
|
box_right.dim = wrapped_box.dim - split_len
|
||||||
|
|
||||||
box_left.c2 = split_point + noise
|
box_left.c2 = split_point + noise
|
||||||
box_right.c1 = split_point + self.noise()
|
box_right.c1 = split_point
|
||||||
|
|
||||||
|
validate_box(box_left.wrapped)
|
||||||
|
validate_box(box_right.wrapped)
|
||||||
|
|
||||||
return box_left.wrapped, box_right.wrapped
|
return box_left.wrapped, box_right.wrapped
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user