group merging done

This commit is contained in:
Matthias Bisping 2022-04-07 18:05:15 +02:00
parent 9bb07f95fb
commit 50b4d239cb

View File

@ -8,7 +8,7 @@ import fpdf
import numpy as np
import pytest
from PIL import Image
from funcy import merge, second, compose, rpartial, juxt, rest, first
from funcy import merge, second, compose, rpartial, juxt, rest, first, one
from image_prediction.estimator.preprocessor.utils import image_to_normalized_tensor
from image_prediction.image_extractor.extractor import ImageMetadataPair
@ -39,10 +39,7 @@ def make_coord_getter(c):
def make_pair_merger(direction):
return {
"y": merge_pair_vertically,
"x": merge_pair_horizontally
}[direction]
return {"y": merge_pair_vertically, "x": merge_pair_horizontally}[direction]
def make_length_getter(dim):
@ -133,10 +130,8 @@ class Stitcher:
groups = map(merge_group, groups)
def merge_group(group, direction="y"):
def merge_group(group, direction):
def merge_with(aggregation_pair, pairs):
def merge_on_head_and_aggregate_in_tail(pairs_aggr: Iterable[ImageMetadataPair], pair: ImageMetadataPair):
"""Keeps the image that is being merged with as the head and aggregates non-mergables in the tail."""
aggr, non_aggr = juxt(first, rest)(pairs_aggr)
@ -163,31 +158,67 @@ def merge_group(group, direction="y"):
return new_pairs
def merge_group_horizontally(group):
return merge_group(group, "x")
def merge_group_vertically(group):
return merge_group(group, "y")
#####################################
def test_merge_group(vertical_merge_test_pairs):
pr1, pr2, pr_merged_expected = vertical_merge_test_pairs
prs_merged = merge_group([pr1, pr2])
def test_merge_group_horizontally(horizontal_merge_test_pairs):
pr1, pr2, pr_merged_expected = horizontal_merge_test_pairs
prs_merged = merge_group_horizontally([pr1, pr2])
assert len(prs_merged) == 1
assert_pair_equal(prs_merged[0], pr_merged_expected)
assert pair_equal(prs_merged[0], pr_merged_expected)
mdat3 = deepcopy(pr2.metadata)
mdat3[Info.HEIGHT] += 30
mdat3[Info.Y2] += 30
im3 = random_size_gray_image_from_metadata(mdat3)
pr3 = ImageMetadataPair(im3, mdat3)
prs_merged = merge_group_horizontally([pr1, pr2, pr3])
assert len(prs_merged) == 2
assert one(partial(pair_equal, pr_merged_expected), prs_merged)
def assert_pair_equal(pr1, pr2):
assert pr1.metadata == pr2.metadata
assert images_equal(pr1.image, pr2.image)
def test_merge_group_vertically(vertical_merge_test_pairs):
pr1, pr2, pr_merged_expected = vertical_merge_test_pairs
prs_merged = merge_group_vertically([pr1, pr2])
assert len(prs_merged) == 1
assert pair_equal(prs_merged[0], pr_merged_expected)
mdat3 = deepcopy(pr2.metadata)
mdat3[Info.WIDTH] += 30
mdat3[Info.X2] += 30
im3 = random_size_gray_image_from_metadata(mdat3)
pr3 = ImageMetadataPair(im3, mdat3)
prs_merged = merge_group_vertically([pr1, pr2, pr3])
assert len(prs_merged) == 2
assert one(partial(pair_equal, pr_merged_expected), prs_merged)
def pair_equal(pr1, pr2):
return pr1.metadata == pr2.metadata and images_equal(pr1.image, pr2.image)
def test_merge_pairs_horizontally(horizontal_merge_test_pairs):
pr1, pr2, pr_merged_expected = horizontal_merge_test_pairs
pr_merged = merge_pair_horizontally(pr1, pr2)
assert_pair_equal(pr_merged, pr_merged_expected)
assert pair_equal(pr_merged, pr_merged_expected)
def test_merge_pairs_vertically(vertical_merge_test_pairs):
pr1, pr2, pr_merged_expected = vertical_merge_test_pairs
pr_merged = merge_pair_vertically(pr1, pr2)
assert_pair_equal(pr_merged, pr_merged_expected)
assert pair_equal(pr_merged, pr_merged_expected)
def images_equal(im1: Image, im2: Image):
@ -218,7 +249,6 @@ def test_merge_metadata_vertically(vertical_merge_test_metadata):
@pytest.fixture
def horizontal_merge_test_metadata(merge_test_metadata):
mdat1, mdat2, mdat_merged = merge_test_metadata
mdat2[Info.X1] = mdat1[Info.X2]
@ -231,7 +261,6 @@ def horizontal_merge_test_metadata(merge_test_metadata):
@pytest.fixture
def vertical_merge_test_metadata(merge_test_metadata):
mdat1, mdat2, mdat_merged = merge_test_metadata
mdat2[Info.Y1] = mdat1[Info.Y2]