335 lines
11 KiB
Python
335 lines
11 KiB
Python
from copy import deepcopy
|
|
from functools import partial, reduce
|
|
from itertools import groupby
|
|
from itertools import starmap, chain, repeat
|
|
from typing import Iterable, List
|
|
|
|
import fpdf
|
|
import numpy as np
|
|
import pytest
|
|
from PIL import Image
|
|
from funcy import merge, second, compose, rpartial, juxt, rest, first, one, iterate
|
|
|
|
from image_prediction.estimator.preprocessor.utils import image_to_normalized_tensor
|
|
from image_prediction.image_extractor.extractor import ImageMetadataPair
|
|
from image_prediction.info import Info
|
|
from image_prediction.utils import chunk_iterable
|
|
from test.conftest import (
|
|
get_base_position_metadata,
|
|
add_image,
|
|
random_single_color_image_from_metadata,
|
|
random_size_gray_image_from_metadata,
|
|
)
|
|
from test.utils.stitching import BoxSplitter, VerticalKeyMapper, HorizontalKeyMapper
|
|
|
|
|
|
def make_getter(key):
    """Build an accessor that looks up `key` in a pair's `metadata` mapping.

    Args:
        key: the metadata key to read.

    Returns:
        A one-argument callable mapping a pair to `pair.metadata[key]`.
    """
    return lambda pair: pair.metadata[key]
|
|
|
|
|
|
def make_coord_getter(c):
    """Return a metadata accessor for the box coordinate named `c`.

    Args:
        c: one of "x1", "x2", "y1", "y2".

    Raises:
        KeyError: if `c` is not one of the names above.
    """
    coordinate_keys = {
        "x1": Info.X1,
        "x2": Info.X2,
        "y1": Info.Y1,
        "y2": Info.Y2,
    }
    return make_getter(coordinate_keys[c])
|
|
|
|
|
|
def make_pair_merger(direction):
    """Return the function that merges two pairs along `direction`.

    Args:
        direction: "x" for side-by-side merging, "y" for top-to-bottom merging.

    Raises:
        KeyError: for any other direction.
    """
    mergers_by_direction = {"y": merge_pair_vertically, "x": merge_pair_horizontally}
    merger = mergers_by_direction[direction]
    return merger
|
|
|
|
|
|
def make_length_getter(dim):
    """Return a metadata accessor for the extent named `dim`.

    Args:
        dim: "width" or "height".

    Raises:
        KeyError: for any other name.
    """
    extent_keys = {"width": Info.WIDTH, "height": Info.HEIGHT}
    return make_getter(extent_keys[dim])
|
|
|
|
|
|
# Ready-made accessors for the box coordinates and extents stored in a pair's metadata.
x1_getter = make_coord_getter("x1")
y1_getter = make_coord_getter("y1")
x2_getter = make_coord_getter("x2")
y2_getter = make_coord_getter("y2")
width_getter = make_length_getter("width")
height_getter = make_length_getter("height")
|
|
|
|
|
|
def merge_metadata_horizontally(m1, m2):
    """Merge the metadata of two side-by-side patches (`m1` left of `m2`).

    Both dicts are wrapped in `HorizontalKeyMapper` so that `merge_metadata` can
    work on generic c1/c2/dim names. Judging by the horizontal fixtures, these
    correspond to x1/x2/width.
    """
    left, right = HorizontalKeyMapper(m1), HorizontalKeyMapper(m2)
    return merge_metadata(left, right)
|
|
|
|
|
|
def merge_metadata_vertically(m1, m2):
    """Merge the metadata of two stacked patches (`m1` above `m2`).

    Both dicts are wrapped in `VerticalKeyMapper` so that `merge_metadata` can
    work on generic c1/c2/dim names. Judging by the vertical fixtures, these
    correspond to y1/y2/height.
    """
    top, bottom = VerticalKeyMapper(m1), VerticalKeyMapper(m2)
    return merge_metadata(top, bottom)
|
|
|
|
|
|
def merge_metadata(m1, m2):
    """Combine two key-mapped metadata objects along their mapped axis.

    The merged span starts at the smaller `c1` and ends at the larger `c2`. The
    merged extent `dim` is the sum of both extents. Every other field comes from a
    deep copy of `m1`, so neither input is modified.

    Returns:
        The `.wrapped` (underlying) metadata of the merged copy.
    """
    start = min(m1.c1, m2.c1)
    end = max(m1.c2, m2.c2)
    extent = m1.dim + m2.dim

    result = deepcopy(m1)
    # Assigned in the order dim, c1, c2.
    result.dim, result.c1, result.c2 = extent, start, end
    return result.wrapped
|
|
|
|
|
|
def merge_pair_horizontally(p1: ImageMetadataPair, p2: ImageMetadataPair):
    """Join two pairs side by side (`p1` on the left, `p2` on the right).

    The metadata are merged first. The images are then pasted onto a canvas sized
    by the merged metadata.
    """
    merged_metadata = merge_metadata_horizontally(p1.metadata, p2.metadata)
    return ImageMetadataPair(
        concat_images_horizontally(p1.image, p2.image, merged_metadata),
        merged_metadata,
    )
|
|
|
|
|
|
def merge_pair_vertically(p1: ImageMetadataPair, p2: ImageMetadataPair):
    """Stack two pairs (`p1` on top, `p2` below).

    The metadata are merged first. The images are then pasted onto a canvas sized
    by the merged metadata.
    """
    merged_metadata = merge_metadata_vertically(p1.metadata, p2.metadata)
    return ImageMetadataPair(
        concat_images_vertically(p1.image, p2.image, merged_metadata),
        merged_metadata,
    )
|
|
|
|
|
|
# def merge_pair(p1, p2):
#     assert p1.metadata[Info.PAGE_IDX] == p2.metadata[Info.PAGE_IDX]
|
|
|
|
|
|
def concat_images_horizontally(im1: Image.Image, im2: Image.Image, metadata: dict):
    """Paste `im2` to the right of `im1` on a canvas sized by `metadata`.

    Args:
        im1: left image.
        im2: right image.
        metadata: merged metadata; its Info.WIDTH / Info.HEIGHT give the canvas size.

    Returns:
        The concatenated image.
    """
    # axis 0 offsets the second image along x (PIL sizes are (width, height)).
    return concat_images(im1, im2, metadata, 0)
|
|
|
|
|
|
def concat_images_vertically(im1: Image.Image, im2: Image.Image, metadata: dict):
    """Paste `im2` below `im1` on a canvas sized by `metadata`.

    Args:
        im1: top image.
        im2: bottom image.
        metadata: merged metadata; its Info.WIDTH / Info.HEIGHT give the canvas size.

    Returns:
        The concatenated image.
    """
    # axis 1 offsets the second image along y (PIL sizes are (width, height)).
    return concat_images(im1, im2, metadata, 1)
|
|
|
|
|
|
def concat_images(im1: Image.Image, im2: Image.Image, metadata: dict, axis):
    """Paste `im1` and then `im2` one after the other onto a new canvas.

    Args:
        im1: first image; placed at offset 0.
        im2: second image; placed right after `im1` along `axis`.
        metadata: merged metadata; Info.WIDTH and Info.HEIGHT give the canvas size.
        axis: 0 to place the images side by side (offset along x), 1 to stack them
            (offset along y). It indexes PIL's `(width, height)` size tuple.

    Returns:
        A new image in the mode of `im1`.
    """
    canvas = Image.new(im1.mode, (metadata[Info.WIDTH], metadata[Info.HEIGHT]))

    # Each image starts where the previous images end along `axis`, so the
    # offset is the running sum of the preceding sizes.
    offset = 0
    for im in (im1, im2):
        canvas.paste(im, box=(0, offset) if axis else (offset, 0))
        offset += im.size[axis]

    return canvas
|
|
|
|
|
|
class Stitcher:
    """Reassemble a collection of image/metadata patches into a single pair.

    Patches are merged by alternating a vertical and a horizontal merge pass
    until a round of passes no longer reduces the number of pairs.
    """

    @staticmethod
    def groupby(pairs, coord):
        """Group `pairs` into lists that share the same value of `coord`.

        Args:
            pairs: iterable of ImageMetadataPair.
            coord: one of "x1", "x2", "y1", "y2".

        Returns:
            A lazy iterator of lists. The pairs are sorted by `coord` first,
            because `itertools.groupby` only groups consecutive items.
        """
        coord_getter = make_coord_getter(coord)
        pairs = sorted(pairs, key=coord_getter)
        return map(compose(list, second), groupby(pairs, coord_getter))

    def _merge_pass(self, pairs, direction):
        """Merge every run of pairs that are adjacent along `direction` ("x" or "y").

        Two pairs can only be concatenated along `direction` when they cover the
        same interval on the other axis. Pairs are therefore first grouped by both
        coordinates of that other axis and then ordered along `direction`.

        Returns:
            A list of the resulting (possibly merged) pairs.
        """
        cross = "y" if direction == "x" else "x"
        groups = self.groupby(pairs, f"{cross}1")
        groups = chain.from_iterable(map(rpartial(self.groupby, f"{cross}2"), groups))
        groups = map(partial(sorted, key=make_coord_getter(f"{direction}1")), groups)
        return list(chain.from_iterable(merge_group(group, direction) for group in groups))

    def stitch(self, pairs: Iterable[ImageMetadataPair]) -> ImageMetadataPair:
        """Merge `pairs` into a single ImageMetadataPair.

        Raises:
            ValueError: if `pairs` is empty, or if the patches cannot be merged
                down to exactly one pair.
        """
        pairs = list(pairs)
        if not pairs:
            raise ValueError("Cannot stitch an empty collection of image-metadata pairs.")

        # Merging never increases the count, so stop once a full round
        # (vertical pass, then horizontal pass) leaves the count unchanged.
        previous_count = None
        while len(pairs) != previous_count:
            previous_count = len(pairs)
            pairs = self._merge_pass(self._merge_pass(pairs, "y"), "x")

        if len(pairs) != 1:
            raise ValueError(f"Could not stitch patches into a single pair; {len(pairs)} pairs remain.")
        return pairs[0]
|
|
|
|
|
|
def merge_group(group, direction):
    """Merge the pairs of `group` that are adjacent along `direction`.

    A pair is appended to an already merged run when the run's end coordinate
    (`x2` / `y2`) equals the pair's start coordinate (`x1` / `y1`). Pairs are
    visited in order of their start coordinate. Assuming positive extents, every
    run that could precede a pair has therefore already been built when the pair
    is visited. This means chains that do not start at the first pair are merged
    as well.

    Args:
        group: iterable of ImageMetadataPair. Callers are expected to pass pairs
            that share the same extent on the other axis; this is not checked here.
        direction: "x" (merge left-to-right) or "y" (merge top-to-bottom).

    Returns:
        A list of merged runs. Pairs that are not adjacent to any run are returned
        unchanged, and an empty group yields an empty list.
    """
    c1_getter = make_coord_getter(f"{direction}1")
    c2_getter = make_coord_getter(f"{direction}2")
    pair_merger = make_pair_merger(direction)

    runs = []
    for pair in sorted(group, key=c1_getter):
        for idx, run in enumerate(runs):
            if c2_getter(run) == c1_getter(pair):
                runs[idx] = pair_merger(run, pair)
                break
        else:
            # No run ends where this pair starts: it begins a new run.
            runs.append(pair)
    return runs
|
|
|
|
|
|
def merge_group_horizontally(group):
    """Merge the side-by-side adjacent pairs of `group` (direction "x")."""
    return merge_group(group, direction="x")
|
|
|
|
|
|
def merge_group_vertically(group):
    """Merge the vertically stacked adjacent pairs of `group` (direction "y")."""
    return merge_group(group, direction="y")
|
|
|
|
|
|
#####################################
|
|
|
|
|
|
def test_merge_group_horizontally(horizontal_merge_test_pairs):
    """Two side-by-side pairs merge into the expected pair.

    A taller copy of the right-hand pair, added to the group, stays separate.
    """
    left, right, expected = horizontal_merge_test_pairs

    merged = merge_group_horizontally([left, right])
    assert len(merged) == 1
    assert pair_equal(merged[0], expected)

    taller_metadata = deepcopy(right.metadata)
    taller_metadata[Info.HEIGHT] += 30
    taller_metadata[Info.Y2] += 30
    taller = ImageMetadataPair(random_size_gray_image_from_metadata(taller_metadata), taller_metadata)

    merged = merge_group_horizontally([left, right, taller])
    assert len(merged) == 2
    assert one(partial(pair_equal, expected), merged)
|
|
|
|
|
|
def test_merge_group_vertically(vertical_merge_test_pairs):
    """Two stacked pairs merge into the expected pair.

    A wider copy of the bottom pair, added to the group, stays separate.
    """
    top, bottom, expected = vertical_merge_test_pairs

    merged = merge_group_vertically([top, bottom])
    assert len(merged) == 1
    assert pair_equal(merged[0], expected)

    wider_metadata = deepcopy(bottom.metadata)
    wider_metadata[Info.WIDTH] += 30
    wider_metadata[Info.X2] += 30
    wider = ImageMetadataPair(random_size_gray_image_from_metadata(wider_metadata), wider_metadata)

    merged = merge_group_vertically([top, bottom, wider])
    assert len(merged) == 2
    assert one(partial(pair_equal, expected), merged)
|
|
|
|
|
|
def pair_equal(pr1, pr2):
    """Return whether two pairs have equal metadata and numerically close images.

    The image comparison is only performed when the metadata already match.
    """
    if pr1.metadata != pr2.metadata:
        return False
    return images_equal(pr1.image, pr2.image)
|
|
|
|
|
|
def test_merge_pairs_horizontally(horizontal_merge_test_pairs):
    """Merging a left and a right pair yields the expected merged pair."""
    left, right, expected = horizontal_merge_test_pairs
    assert pair_equal(merge_pair_horizontally(left, right), expected)
|
|
|
|
|
|
def test_merge_pairs_vertically(vertical_merge_test_pairs):
    """Merging a top and a bottom pair yields the expected merged pair."""
    top, bottom, expected = vertical_merge_test_pairs
    assert pair_equal(merge_pair_vertically(top, bottom), expected)
|
|
|
|
|
|
def images_equal(im1: Image.Image, im2: Image.Image):
    """Return True if both images have the same shape and numerically close pixels.

    Both images are converted with `image_to_normalized_tensor` and compared with
    `np.allclose` at its default tolerances. The explicit shape check makes images
    of different sizes compare unequal. Without it, `np.allclose` would raise on
    them or broadcast one against the other.
    """
    t1, t2 = image_to_normalized_tensor(im1), image_to_normalized_tensor(im2)
    return np.shape(t1) == np.shape(t2) and bool(np.allclose(t1, t2))
|
|
|
|
|
|
@pytest.fixture
def horizontal_merge_test_pairs(horizontal_merge_test_metadata):
    """Attach a random gray image to each of the (left, right, merged) metadata dicts."""
    return [
        ImageMetadataPair(random_size_gray_image_from_metadata(metadata), metadata)
        for metadata in horizontal_merge_test_metadata
    ]
|
|
|
|
|
|
@pytest.fixture
def vertical_merge_test_pairs(vertical_merge_test_metadata):
    """Attach a random gray image to each of the (top, bottom, merged) metadata dicts."""
    return [
        ImageMetadataPair(random_size_gray_image_from_metadata(metadata), metadata)
        for metadata in vertical_merge_test_metadata
    ]
|
|
|
|
|
|
def test_merge_metadata_horizontally(horizontal_merge_test_metadata):
    """Merging left and right metadata matches the precomputed merged metadata."""
    left, right, expected = horizontal_merge_test_metadata
    merged = merge_metadata_horizontally(left, right)
    assert merged == expected
|
|
|
|
|
|
def test_merge_metadata_vertically(vertical_merge_test_metadata):
    """Merging top and bottom metadata matches the precomputed merged metadata."""
    top, bottom, expected = vertical_merge_test_metadata
    merged = merge_metadata_vertically(top, bottom)
    assert merged == expected
|
|
|
|
|
|
@pytest.fixture
def horizontal_merge_test_metadata(merge_test_metadata):
    """Build (left, right, merged) metadata for two side-by-side patches.

    `right` is moved so that it starts where `left` ends, keeping its own width.
    `merged` spans both patches.
    """
    left, right, merged = merge_test_metadata

    right[Info.X1] = left[Info.X2]
    right[Info.X2] = right[Info.X1] + right[Info.WIDTH]

    merged.update(
        {
            Info.WIDTH: left[Info.WIDTH] + right[Info.WIDTH],
            Info.X2: right[Info.X2],
        }
    )

    return left, right, merged
|
|
|
|
|
|
@pytest.fixture
def vertical_merge_test_metadata(merge_test_metadata):
    """Build (top, bottom, merged) metadata for two stacked patches.

    `bottom` is moved so that it starts where `top` ends, keeping its own height.
    `merged` spans both patches.
    """
    top, bottom, merged = merge_test_metadata

    bottom[Info.Y1] = top[Info.Y2]
    bottom[Info.Y2] = bottom[Info.Y1] + bottom[Info.HEIGHT]

    merged.update(
        {
            Info.HEIGHT: top[Info.HEIGHT] + bottom[Info.HEIGHT],
            Info.Y2: bottom[Info.Y2],
        }
    )

    return top, bottom, merged
|
|
|
|
|
|
@pytest.fixture
def merge_test_metadata(base_patch_metadata):
    """Provide three independent deep copies of `base_patch_metadata`.

    Dependent fixtures unpack them as (first, second, merged) and mutate them in
    place, so each copy must be independent. A concrete tuple is returned instead
    of funcy's lazy `juxt` result, so the fixture value can be iterated more than
    once.
    """
    return tuple(deepcopy(base_patch_metadata) for _ in range(3))
|
|
|
|
|
|
def test_concat_images_horizontally(horizontal_merge_test_metadata):
    """Side-by-side concatenation matches the image generated from the merged metadata."""
    left_meta, right_meta, merged_meta = horizontal_merge_test_metadata
    left, right, expected = (
        random_size_gray_image_from_metadata(meta) for meta in (left_meta, right_meta, merged_meta)
    )

    concatenated = concat_images_horizontally(left, right, merged_meta)

    assert concatenated.size == expected.size
    assert images_equal(concatenated, expected)
|
|
|
|
|
|
def test_concat_images_vertically(vertical_merge_test_metadata):
    """Stacked concatenation matches the image generated from the merged metadata."""
    top_meta, bottom_meta, merged_meta = vertical_merge_test_metadata
    top, bottom, expected = (
        random_size_gray_image_from_metadata(meta) for meta in (top_meta, bottom_meta, merged_meta)
    )

    concatenated = concat_images_vertically(top, bottom, merged_meta)

    assert concatenated.size == expected.size
    assert images_equal(concatenated, expected)
|
|
|
|
|
|
@pytest.mark.parametrize("width", [160])
@pytest.mark.parametrize("height", [90])
@pytest.mark.parametrize("page_width", [int(160 * 1.1)])
@pytest.mark.parametrize("page_height", [int(90 * 1.1)])
@pytest.mark.skip()
def test_image_stitcher(patches_metadata, base_patch_metadata, patch_image_metadata_pairs):
    """Stitching the patches split from the base box reproduces the base box metadata."""
    # `patch_image_metadata_pairs` must be requested as a fixture; previously the
    # name was used without being a parameter, a NameError once unskipped.
    # noinspection PyTypeChecker
    assert Stitcher().stitch(patch_image_metadata_pairs).metadata == base_patch_metadata
|
|
|
|
|
|
@pytest.mark.parametrize("width", [160])
@pytest.mark.parametrize("height", [90])
@pytest.mark.parametrize("page_width", [int(160 * 1.1)])
@pytest.mark.parametrize("page_height", [int(90 * 1.1)])
def test_partial_image_metadata_pairs(patch_image_metadata_pairs, page_width, page_height, tmp_path):
    """Render every patch with `add_image` into one PDF document and check it is written.

    The PDF goes to pytest's per-test temporary directory rather than a fixed
    `/tmp` path, so runs do not collide and the test is portable.
    """
    pdf = fpdf.FPDF(unit="pt", format=(page_width, page_height))

    for pair in patch_image_metadata_pairs:
        add_image(pdf, pair)

    output_path = tmp_path / "patches.pdf"
    pdf.output(str(output_path))
    assert output_path.is_file() and output_path.stat().st_size > 0
|
|
|
|
|
|
@pytest.fixture
def patch_image_metadata_pairs(patches_metadata) -> List[ImageMetadataPair]:
    """Pair each patch's metadata with a random single-colour image generated from it."""
    return [
        ImageMetadataPair(random_single_color_image_from_metadata(metadata), metadata)
        for metadata in patches_metadata
    ]
|
|
|
|
|
|
@pytest.fixture
def base_patch_metadata(width, height, page_width, page_height):
    """Metadata for a `width` x `height` box anchored at the origin.

    The page size is passed through to `get_base_position_metadata`. The box
    coordinates are then overridden to span (0, 0) to (width, height).
    """
    base = get_base_position_metadata(width, height, page_width, page_height)
    origin_box = {Info.X1: 0, Info.Y1: 0, Info.X2: width, Info.Y2: height}
    return merge(base, origin_box)
|
|
|
|
|
|
@pytest.fixture
def patches_metadata(base_patch_metadata):
    """Metadata of the patches that `BoxSplitter` cuts out of the base box."""
    splitter = BoxSplitter()
    return list(splitter.split_box(base_patch_metadata))
|