fuzzy stitching completed

This commit is contained in:
Matthias Bisping 2022-04-12 15:04:32 +02:00
parent bb7c1be630
commit d8f86d14a5
7 changed files with 146 additions and 19 deletions

View File

@ -9,3 +9,15 @@ class EnumFormatter(KeyFormatter):
def transform(self, obj):
raise NotImplementedError
class ReverseEnumFormatter(KeyFormatter):
    """Key formatter that turns raw enum values back into the matching enum members."""

    def __init__(self, enum):
        self.enum = enum
        # Lookup table: member value -> member, built once from the given enum.
        self.reverse_enum = dict((member.value, member) for member in enum)

    def format_key(self, key):
        """Return the enum member whose value equals ``key``; unknown keys pass through unchanged."""
        lookup = self.reverse_enum
        return lookup[key] if key in lookup else key

    def transform(self, obj):
        """Not supported by this formatter."""
        raise NotImplementedError

View File

@ -8,9 +8,9 @@ from funcy import juxt, first, rest, rcompose, rpartial
from image_prediction.image_extractor.extractor import ImageMetadataPair
from image_prediction.info import Info
from image_prediction.stitching.grouping import CoordGrouper
from image_prediction.stitching.utils import make_coord_getter, flatten_groups_once
from image_prediction.utils.generic import until
from image_prediction.stitching.split_mapper import HorizontalSplitMapper, VerticalSplitMapper
from image_prediction.stitching.utils import make_coord_getter, flatten_groups_once, validate_box
from image_prediction.utils.generic import until
def no_new_merges(pairs1, pairs2):
@ -139,13 +139,15 @@ def merge_metadata(m1: dict, m2: dict):
c1 = min(m1.c1, m2.c1)
c2 = max(m1.c2, m2.c2)
dim = m1.dim + m2.dim
dim = abs(c2 - c1)
merged = deepcopy(m1)
merged.dim = dim
merged.c1 = c1
merged.c2 = c2
validate_box(merged.wrapped)
return merged.wrapped
@ -163,7 +165,7 @@ def concat_images(im1: Image, im2: Image, metadata: dict, axis):
images = [im1, im2]
offsets = [0, *[im.size[axis] for im in images]]
offsets = 0, im1.size[axis], im_aggr.size[axis] - im2.size[axis]
for im, offset in zip(images, offsets):
box = (offset, 0) if not axis else (0, offset)

View File

@ -28,3 +28,8 @@ def make_length_getter(dim):
"width": make_getter(Info.WIDTH),
"height": make_getter(Info.HEIGHT),
}[dim]
def validate_box(box):
    """Assert that a box's width and height agree with its corner coordinates.

    For each axis, the difference between the upper and lower coordinate must
    equal the stored extent; an ``AssertionError`` is raised otherwise.
    """
    axes = (
        (Info.X1, Info.X2, Info.WIDTH),
        (Info.Y1, Info.Y2, Info.HEIGHT),
    )
    # The x axis is checked before the y axis.
    for lower, upper, extent in axes:
        assert box[upper] - box[lower] == box[extent]

View File

@ -494,6 +494,7 @@ def random_single_color_image_from_metadata(metadata):
return image
# TODO: rename this function: the image it builds is not random (fixed gray color, size taken from the metadata).
def random_size_gray_image_from_metadata(metadata):
    """Build a solid gray RGB image whose size is read from ``metadata``.

    The width and height are taken from ``metadata[Info.WIDTH]`` and
    ``metadata[Info.HEIGHT]``; every pixel is set to (100, 100, 100).
    """
    size = (metadata[Info.WIDTH], metadata[Info.HEIGHT])
    gray = (100, 100, 100)
    return Image.new("RGB", size, color=gray)

View File

@ -0,0 +1,92 @@
{
"input": [
{
"width": 100,
"height": 8,
"page_idx": 0,
"page_width": 100,
"page_height": 100,
"x1": 0,
"y1": 0,
"x2": 100,
"y2": 8
},
{
"width": 100,
"height": 9,
"page_idx": 0,
"page_width": 100,
"page_height": 100,
"x1": 0,
"y1": 9,
"x2": 100,
"y2": 18
},
{
"width": 100,
"height": 35,
"page_idx": 0,
"page_width": 100,
"page_height": 100,
"x1": 0,
"y1": 18,
"x2": 100,
"y2": 53
},
{
"width": 47,
"height": 46,
"page_idx": 0,
"page_width": 100,
"page_height": 100,
"x1": 0,
"y1": 54,
"x2": 47,
"y2": 100
},
{
"width": 31,
"height": 46,
"page_idx": 0,
"page_width": 100,
"page_height": 100,
"x1": 48,
"y1": 54,
"x2": 79,
"y2": 100
},
{
"width": 20,
"height": 19,
"page_idx": 0,
"page_width": 100,
"page_height": 100,
"x1": 80,
"y1": 54,
"x2": 100,
"y2": 73
},
{
"width": 20,
"height": 27,
"page_idx": 0,
"page_width": 100,
"page_height": 100,
"x1": 80,
"y1": 73,
"x2": 100,
"y2": 100
}
],
"target": {
"width": 100,
"height": 100,
"page_idx": 0,
"page_width": 100,
"page_height": 100,
"x1": 0,
"y1": 0,
"x2": 100,
"y2": 100
}
}

View File

@ -1,3 +1,5 @@
import json
import os
from copy import deepcopy
from copy import deepcopy
from functools import partial
@ -13,6 +15,7 @@ from PIL import Image
from funcy import juxt, one, first
from image_prediction.estimator.preprocessor.utils import image_to_normalized_tensor
from image_prediction.formatter.formatters.enum import ReverseEnumFormatter
from image_prediction.image_extractor.extractor import ImageMetadataPair
from image_prediction.info import Info
from image_prediction.stitching.grouping import group_by_coordinate
@ -63,6 +66,22 @@ def test_image_stitcher(patch_image_metadata_pairs, base_patch_metadata, base_pa
assert images_equal(pair_stitched.image.resize((10, 10)), base_patch_image.resize((10, 10)), atol=0.4)
def test_image_stitcher_with_gaps_must_succeed():
    """Stitch the patches from ``stitching_with_tolerance.json`` and expect exactly the stored target box."""
    from image_prediction.locations import TEST_DATA_DIR

    fixture_path = os.path.join(TEST_DATA_DIR, "stitching_with_tolerance.json")
    with open(fixture_path) as fixture_file:
        # Run the parsed JSON through ReverseEnumFormatter(Info) before picking out its two sections.
        fixture = ReverseEnumFormatter(Info)(json.load(fixture_file))
        patches_metadata = fixture["input"]
        base_patch_metadata = fixture["target"]

    # One gray image per patch, paired with the metadata it was built from.
    gray_images = (random_size_gray_image_from_metadata(metadata) for metadata in patches_metadata)
    pairs = [ImageMetadataPair(image, metadata) for image, metadata in zip(gray_images, patches_metadata)]

    stitched = stitch_pairs(pairs, tolerance=7)

    assert len(stitched) == 1
    assert first(stitched).metadata == base_patch_metadata
@pytest.mark.parametrize("noise", [(0, 2)])
@pytest.mark.parametrize("split_count", [5])
@pytest.mark.parametrize("width", [100])
@ -70,18 +89,10 @@ def test_image_stitcher(patch_image_metadata_pairs, base_patch_metadata, base_pa
@pytest.mark.parametrize("page_width", [100])
@pytest.mark.parametrize("page_height", [100])
@pytest.mark.parametrize("execution_number", range(100))
def test_image_stitcher_with_gaps(patch_image_metadata_pairs, base_patch_metadata, base_patch_image, execution_number):
print(len(patch_image_metadata_pairs))
pairs_stitched = stitch_pairs(patch_image_metadata_pairs, tolerance=7)
try:
assert len(pairs_stitched) == 1
except:
for p in pairs_stitched:
p.image.show()
base_patch_image.show()
import IPython
IPython.embed()
@pytest.mark.xfail(reason="Does not always succeed due to locally maximizing merging logic.")
def test_image_stitcher_with_gaps_can_fail(patch_image_metadata_pairs, base_patch_metadata, execution_number):
    """Stitch with tolerance 4 and check for a single box matching the base patch; marked xfail."""
    stitched = stitch_pairs(patch_image_metadata_pairs, tolerance=4)
    # Both conditions must hold; the second is only evaluated when the first passes.
    assert len(stitched) == 1
    assert first(stitched).metadata == base_patch_metadata
def test_merge_group_horizontally(horizontal_merge_test_pairs):

View File

@ -2,9 +2,10 @@ import random
from copy import deepcopy
from itertools import chain
from funcy import rpartial, juxt, first
from funcy import rpartial, juxt
from image_prediction.stitching.split_mapper import SplitMapper, HorizontalSplitMapper, VerticalSplitMapper
from image_prediction.stitching.utils import validate_box
class BoxSplitter:
@ -66,11 +67,14 @@ class BoxSplitter:
box_left, box_right = juxt(deepcopy, deepcopy)(wrapped_box)
noise = - self.noise()
noise = -self.noise()
box_left.dim = split_len + noise
box_right.dim = wrapped_box.dim - split_len
box_left.c2 = split_point + noise
box_right.c1 = split_point + self.noise()
box_right.c1 = split_point
validate_box(box_left.wrapped)
validate_box(box_right.wrapped)
return box_left.wrapped, box_right.wrapped