group merging wip

This commit is contained in:
Matthias Bisping 2022-04-07 17:18:09 +02:00
parent 2b1e7cbb08
commit 3e882dc247

View File

@ -49,15 +49,6 @@ x1_getter, y1_getter, x2_getter, y2_getter = map(make_coord_getter, ("x1", "y1",
width_getter, height_getter = map(make_length_getter, ("width", "height")) width_getter, height_getter = map(make_length_getter, ("width", "height"))
def merge_group(group, axis="y"):
group = list(group)
current_pair = group.pop(0)
for pair in group:
if y2_getter(current_pair) == y1_getter(pair):
current_box = merge_pair_vertically(current_pair, pair)
def merge_metadata_horizontally(m1, m2): def merge_metadata_horizontally(m1, m2):
m1, m2 = map(HorizontalKeyMapper, [m1, m2]) m1, m2 = map(HorizontalKeyMapper, [m1, m2])
return merge_metadata(m1, m2) return merge_metadata(m1, m2)
@ -121,25 +112,73 @@ def concat_images(im1: Image, im2: Image, metadata: dict, axis):
return im_aggr return im_aggr
class Stitcher:
@staticmethod
def groupby(pairs, coord):
coord_getter = make_coord_getter(coord)
pairs = sorted(pairs, key=coord_getter)
return map(compose(list, second), groupby(pairs, coord_getter))
def stitch(self, pairs: Iterable[ImageMetadataPair]) -> ImageMetadataPair:
groups = self.groupby(pairs, "x1")
groups = chain.from_iterable(map(rpartial(self.groupby, "x2"), groups))
groups = map(partial(sorted, key=y1_getter), groups)
groups = map(merge_group, groups)
def merge_group(group):
def f(pairs):
current_pair = pairs.pop(0)
to_remove = []
for pair in pairs:
if y2_getter(current_pair) == y1_getter(pair):
current_pair = merge_pair_vertically(current_pair, pair)
to_remove.append(pair)
return [current_pair, *filter(lambda p: p not in to_remove, pairs)]
pairs = list(group)
while True:
new_pairs = f(deepcopy(pairs))
if len(new_pairs) == len(pairs):
break
pairs = new_pairs
return new_pairs
##################################### #####################################
def test_merge_group(vertical_merge_test_pairs):
pr1, pr2, pr_merged_expected = vertical_merge_test_pairs
prs_merged = merge_group([pr1, pr2])
assert len(prs_merged) == 1
assert_pair_equal(prs_merged[0], pr_merged_expected)
def assert_pair_equal(pr1, pr2):
assert pr1.metadata == pr2.metadata
assert images_equal(pr1.image, pr2.image)
def test_merge_pairs_horizontally(horizontal_merge_test_pairs): def test_merge_pairs_horizontally(horizontal_merge_test_pairs):
pr1, pr2, pr_merged_expected = horizontal_merge_test_pairs pr1, pr2, pr_merged_expected = horizontal_merge_test_pairs
pr_merged = merge_pair_horizontally(pr1, pr2) pr_merged = merge_pair_horizontally(pr1, pr2)
assert pr_merged.metadata == pr_merged_expected.metadata assert_pair_equal(pr_merged, pr_merged_expected)
images_equal(pr_merged.image, pr_merged_expected.image)
def test_merge_pairs_vertically(vertical_merge_test_pairs): def test_merge_pairs_vertically(vertical_merge_test_pairs):
pr1, pr2, pr_merged_expected = vertical_merge_test_pairs pr1, pr2, pr_merged_expected = vertical_merge_test_pairs
pr_merged = merge_pair_vertically(pr1, pr2) pr_merged = merge_pair_vertically(pr1, pr2)
assert pr_merged.metadata == pr_merged_expected.metadata assert_pair_equal(pr_merged, pr_merged_expected)
images_equal(pr_merged.image, pr_merged_expected.image)
def images_equal(im1: Image, im2: Image): def images_equal(im1: Image, im2: Image):
assert np.allclose(image_to_normalized_tensor(im1), image_to_normalized_tensor(im2)) return np.allclose(image_to_normalized_tensor(im1), image_to_normalized_tensor(im2))
@pytest.fixture @pytest.fixture
@ -200,28 +239,15 @@ def test_concat_images_horizontally(horizontal_merge_test_metadata):
im1, im2, im_merged_expected = map(random_size_gray_image_from_metadata, [mdat1, mdat2, mdat_merged]) im1, im2, im_merged_expected = map(random_size_gray_image_from_metadata, [mdat1, mdat2, mdat_merged])
im_merged = concat_images_horizontally(im1, im2, mdat_merged) im_merged = concat_images_horizontally(im1, im2, mdat_merged)
assert im_merged.size == im_merged_expected.size assert im_merged.size == im_merged_expected.size
images_equal(im_merged, im_merged_expected) assert images_equal(im_merged, im_merged_expected)
def test_concat_images_vertically(vertical_merge_test_metadata): def test_concat_images_vertically(vertical_merge_test_metadata):
mdat1, mdat2, mdat_merged = vertical_merge_test_metadata mdat1, mdat2, mdat_merged = vertical_merge_test_metadata
im1, im2, im_merged_expected = map(random_single_color_image_from_metadata, [mdat1, mdat2, mdat_merged]) im1, im2, im_merged_expected = map(random_size_gray_image_from_metadata, [mdat1, mdat2, mdat_merged])
im_merged = concat_images_vertically(im1, im2, mdat_merged) im_merged = concat_images_vertically(im1, im2, mdat_merged)
assert im_merged.size == im_merged_expected.size assert im_merged.size == im_merged_expected.size
assert images_equal(im_merged, im_merged_expected)
class Stitcher:
@staticmethod
def groupby(pairs, coord):
coord_getter = make_coord_getter(coord)
pairs = sorted(pairs, key=coord_getter)
return map(compose(list, second), groupby(pairs, coord_getter))
def stitch(self, pairs: Iterable[ImageMetadataPair]) -> ImageMetadataPair:
groups = self.groupby(pairs, "x1")
groups = chain.from_iterable(map(rpartial(self.groupby, "x2"), groups))
groups = map(partial(sorted, key=y1_getter), groups)
groups = map(merge_group, groups)
@pytest.mark.parametrize("width", [160]) @pytest.mark.parametrize("width", [160])