group merging wip
This commit is contained in:
parent
2b1e7cbb08
commit
3e882dc247
@ -49,15 +49,6 @@ x1_getter, y1_getter, x2_getter, y2_getter = map(make_coord_getter, ("x1", "y1",
|
|||||||
width_getter, height_getter = map(make_length_getter, ("width", "height"))
|
width_getter, height_getter = map(make_length_getter, ("width", "height"))
|
||||||
|
|
||||||
|
|
||||||
def merge_group(group, axis="y"):
|
|
||||||
|
|
||||||
group = list(group)
|
|
||||||
current_pair = group.pop(0)
|
|
||||||
for pair in group:
|
|
||||||
if y2_getter(current_pair) == y1_getter(pair):
|
|
||||||
current_box = merge_pair_vertically(current_pair, pair)
|
|
||||||
|
|
||||||
|
|
||||||
def merge_metadata_horizontally(m1, m2):
|
def merge_metadata_horizontally(m1, m2):
|
||||||
m1, m2 = map(HorizontalKeyMapper, [m1, m2])
|
m1, m2 = map(HorizontalKeyMapper, [m1, m2])
|
||||||
return merge_metadata(m1, m2)
|
return merge_metadata(m1, m2)
|
||||||
@ -121,25 +112,73 @@ def concat_images(im1: Image, im2: Image, metadata: dict, axis):
|
|||||||
return im_aggr
|
return im_aggr
|
||||||
|
|
||||||
|
|
||||||
|
class Stitcher:
|
||||||
|
@staticmethod
|
||||||
|
def groupby(pairs, coord):
|
||||||
|
coord_getter = make_coord_getter(coord)
|
||||||
|
pairs = sorted(pairs, key=coord_getter)
|
||||||
|
return map(compose(list, second), groupby(pairs, coord_getter))
|
||||||
|
|
||||||
|
def stitch(self, pairs: Iterable[ImageMetadataPair]) -> ImageMetadataPair:
|
||||||
|
groups = self.groupby(pairs, "x1")
|
||||||
|
groups = chain.from_iterable(map(rpartial(self.groupby, "x2"), groups))
|
||||||
|
groups = map(partial(sorted, key=y1_getter), groups)
|
||||||
|
groups = map(merge_group, groups)
|
||||||
|
|
||||||
|
|
||||||
|
def merge_group(group):
|
||||||
|
def f(pairs):
|
||||||
|
current_pair = pairs.pop(0)
|
||||||
|
|
||||||
|
to_remove = []
|
||||||
|
|
||||||
|
for pair in pairs:
|
||||||
|
if y2_getter(current_pair) == y1_getter(pair):
|
||||||
|
current_pair = merge_pair_vertically(current_pair, pair)
|
||||||
|
to_remove.append(pair)
|
||||||
|
|
||||||
|
return [current_pair, *filter(lambda p: p not in to_remove, pairs)]
|
||||||
|
|
||||||
|
pairs = list(group)
|
||||||
|
|
||||||
|
while True:
|
||||||
|
new_pairs = f(deepcopy(pairs))
|
||||||
|
if len(new_pairs) == len(pairs):
|
||||||
|
break
|
||||||
|
pairs = new_pairs
|
||||||
|
|
||||||
|
return new_pairs
|
||||||
|
|
||||||
|
|
||||||
#####################################
|
#####################################
|
||||||
|
|
||||||
|
|
||||||
|
def test_merge_group(vertical_merge_test_pairs):
|
||||||
|
pr1, pr2, pr_merged_expected = vertical_merge_test_pairs
|
||||||
|
prs_merged = merge_group([pr1, pr2])
|
||||||
|
assert len(prs_merged) == 1
|
||||||
|
assert_pair_equal(prs_merged[0], pr_merged_expected)
|
||||||
|
|
||||||
|
|
||||||
|
def assert_pair_equal(pr1, pr2):
|
||||||
|
assert pr1.metadata == pr2.metadata
|
||||||
|
assert images_equal(pr1.image, pr2.image)
|
||||||
|
|
||||||
|
|
||||||
def test_merge_pairs_horizontally(horizontal_merge_test_pairs):
|
def test_merge_pairs_horizontally(horizontal_merge_test_pairs):
|
||||||
pr1, pr2, pr_merged_expected = horizontal_merge_test_pairs
|
pr1, pr2, pr_merged_expected = horizontal_merge_test_pairs
|
||||||
pr_merged = merge_pair_horizontally(pr1, pr2)
|
pr_merged = merge_pair_horizontally(pr1, pr2)
|
||||||
assert pr_merged.metadata == pr_merged_expected.metadata
|
assert_pair_equal(pr_merged, pr_merged_expected)
|
||||||
images_equal(pr_merged.image, pr_merged_expected.image)
|
|
||||||
|
|
||||||
|
|
||||||
def test_merge_pairs_vertically(vertical_merge_test_pairs):
|
def test_merge_pairs_vertically(vertical_merge_test_pairs):
|
||||||
pr1, pr2, pr_merged_expected = vertical_merge_test_pairs
|
pr1, pr2, pr_merged_expected = vertical_merge_test_pairs
|
||||||
pr_merged = merge_pair_vertically(pr1, pr2)
|
pr_merged = merge_pair_vertically(pr1, pr2)
|
||||||
assert pr_merged.metadata == pr_merged_expected.metadata
|
assert_pair_equal(pr_merged, pr_merged_expected)
|
||||||
images_equal(pr_merged.image, pr_merged_expected.image)
|
|
||||||
|
|
||||||
|
|
||||||
def images_equal(im1: Image, im2: Image):
|
def images_equal(im1: Image, im2: Image):
|
||||||
assert np.allclose(image_to_normalized_tensor(im1), image_to_normalized_tensor(im2))
|
return np.allclose(image_to_normalized_tensor(im1), image_to_normalized_tensor(im2))
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
@ -200,28 +239,15 @@ def test_concat_images_horizontally(horizontal_merge_test_metadata):
|
|||||||
im1, im2, im_merged_expected = map(random_size_gray_image_from_metadata, [mdat1, mdat2, mdat_merged])
|
im1, im2, im_merged_expected = map(random_size_gray_image_from_metadata, [mdat1, mdat2, mdat_merged])
|
||||||
im_merged = concat_images_horizontally(im1, im2, mdat_merged)
|
im_merged = concat_images_horizontally(im1, im2, mdat_merged)
|
||||||
assert im_merged.size == im_merged_expected.size
|
assert im_merged.size == im_merged_expected.size
|
||||||
images_equal(im_merged, im_merged_expected)
|
assert images_equal(im_merged, im_merged_expected)
|
||||||
|
|
||||||
|
|
||||||
def test_concat_images_vertically(vertical_merge_test_metadata):
|
def test_concat_images_vertically(vertical_merge_test_metadata):
|
||||||
mdat1, mdat2, mdat_merged = vertical_merge_test_metadata
|
mdat1, mdat2, mdat_merged = vertical_merge_test_metadata
|
||||||
im1, im2, im_merged_expected = map(random_single_color_image_from_metadata, [mdat1, mdat2, mdat_merged])
|
im1, im2, im_merged_expected = map(random_size_gray_image_from_metadata, [mdat1, mdat2, mdat_merged])
|
||||||
im_merged = concat_images_vertically(im1, im2, mdat_merged)
|
im_merged = concat_images_vertically(im1, im2, mdat_merged)
|
||||||
assert im_merged.size == im_merged_expected.size
|
assert im_merged.size == im_merged_expected.size
|
||||||
|
assert images_equal(im_merged, im_merged_expected)
|
||||||
|
|
||||||
class Stitcher:
|
|
||||||
@staticmethod
|
|
||||||
def groupby(pairs, coord):
|
|
||||||
coord_getter = make_coord_getter(coord)
|
|
||||||
pairs = sorted(pairs, key=coord_getter)
|
|
||||||
return map(compose(list, second), groupby(pairs, coord_getter))
|
|
||||||
|
|
||||||
def stitch(self, pairs: Iterable[ImageMetadataPair]) -> ImageMetadataPair:
|
|
||||||
groups = self.groupby(pairs, "x1")
|
|
||||||
groups = chain.from_iterable(map(rpartial(self.groupby, "x2"), groups))
|
|
||||||
groups = map(partial(sorted, key=y1_getter), groups)
|
|
||||||
groups = map(merge_group, groups)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("width", [160])
|
@pytest.mark.parametrize("width", [160])
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user