2022-04-07 11:42:38 +02:00

234 lines
7.0 KiB
Python

from copy import deepcopy
from functools import partial
from itertools import starmap, chain, repeat
from operator import itemgetter, attrgetter
from typing import Iterable, List
import fpdf
import pytest
from PIL import Image
from funcy import merge, second, compose, rpartial, curry, juxt, project, omit, flip, identity
from image_prediction.image_extractor.extractor import ImageMetadataPair
from image_prediction.info import Info
from test.conftest import get_base_position_metadata, add_image, random_single_color_image_from_metadata
from test.utils.stitching import BoxSplitter, VerticalKeyMapper, HorizontalKeyMapper
from itertools import groupby
def make_getter(key):
    """Build an accessor that reads ``key`` from the ``metadata`` of a pair.

    Args:
        key: Key to look up in ``pair.metadata``.

    Returns:
        A one-argument callable taking a pair and returning ``pair.metadata[key]``.
    """
    return lambda pair: pair.metadata[key]
def make_coord_getter(c):
    """Return the metadata getter for the coordinate named ``c``.

    Args:
        c: One of ``"x1"``, ``"x2"``, ``"y1"`` or ``"y2"``.

    Returns:
        A callable that reads the matching ``Info`` coordinate from a pair's metadata.

    Raises:
        KeyError: If ``c`` is not one of the four supported names.
    """
    info_key_by_name = {
        "x1": Info.X1,
        "x2": Info.X2,
        "y1": Info.Y1,
        "y2": Info.Y2,
    }
    return make_getter(info_key_by_name[c])
def make_length_getter(dim):
    """Return the metadata getter for the length named ``dim``.

    Args:
        dim: Either ``"width"`` or ``"height"``.

    Returns:
        A callable that reads the matching ``Info`` length from a pair's metadata.

    Raises:
        KeyError: If ``dim`` is neither ``"width"`` nor ``"height"``.
    """
    info_key_by_name = {
        "width": Info.WIDTH,
        "height": Info.HEIGHT,
    }
    return make_getter(info_key_by_name[dim])
# Module-level accessors for single coordinates / lengths of a pair's metadata.
x1_getter = make_coord_getter("x1")
y1_getter = make_coord_getter("y1")
x2_getter = make_coord_getter("x2")
y2_getter = make_coord_getter("y2")
width_getter = make_length_getter("width")
height_getter = make_length_getter("height")
def merge_group(group, axis="y"):
    """Compare the first pair of ``group`` with each later pair and merge touching ones.

    Whenever the ``y2`` of the first pair equals the ``y1`` of a later pair,
    both are handed to ``merge_pair_vertically``.

    NOTE(review): this looks unfinished -- the result of the merge is discarded,
    ``axis`` is never read, and the function returns nothing.

    Args:
        group: Iterable of pairs; it is copied into a list before use.
        axis: Currently unused.

    Raises:
        IndexError: If ``group`` is empty.
    """
    pairs = list(group)
    first, followers = pairs[0], pairs[1:]
    for follower in followers:
        if y2_getter(first) != y1_getter(follower):
            continue
        merge_pair_vertically(first, follower)
def merge_metadata_horizontally(m1, m2):
    """Merge two metadata objects after wrapping each in ``HorizontalKeyMapper``.

    Returns:
        Whatever ``merge_metadata`` returns for the two wrapped inputs.
    """
    left = HorizontalKeyMapper(m1)
    right = HorizontalKeyMapper(m2)
    return merge_metadata(left, right)
def merge_metadata_vertically(m1, m2):
    """Merge two metadata objects after wrapping each in ``VerticalKeyMapper``.

    Returns:
        Whatever ``merge_metadata`` returns for the two wrapped inputs.
    """
    upper = VerticalKeyMapper(m1)
    lower = VerticalKeyMapper(m2)
    return merge_metadata(upper, lower)
def merge_metadata(m1, m2):
    """Combine two key-mapped metadata objects into one.

    The result starts as a deep copy of ``m1`` (so ``m1`` itself is left
    untouched) and then gets:

    * ``dim``: the sum of both ``dim`` values,
    * ``c1``: the smaller of both ``c1`` values,
    * ``c2``: the larger of both ``c2`` values.

    Args:
        m1: Object exposing ``c1``, ``c2``, ``dim`` and ``wrapped``.
        m2: Object exposing ``c1``, ``c2`` and ``dim``.

    Returns:
        The ``wrapped`` attribute of the merged copy.
    """
    merged = deepcopy(m1)
    merged.dim = m1.dim + m2.dim
    merged.c1 = min(m1.c1, m2.c1)
    merged.c2 = max(m1.c2, m2.c2)
    return merged.wrapped
def merge_pair_horizontally(p1: ImageMetadataPair, p2: ImageMetadataPair):
    """Merge two horizontally adjacent pairs into a single pair.

    The metadata of both pairs is merged with ``merge_metadata_horizontally``
    and the images are concatenated with ``concat_images_horizontally`` using
    the merged metadata for the size of the result.

    Args:
        p1: First pair.
        p2: Second pair.

    Returns:
        A new ``ImageMetadataPair`` holding the concatenated image and the
        merged metadata.
    """
    mdat_merged = merge_metadata_horizontally(p1.metadata, p2.metadata)
    # NOTE(review): assumes ImageMetadataPair exposes its image as `.image`
    # (it is constructed positionally as (image, metadata)) -- TODO confirm.
    im_merged = concat_images_horizontally(p1.image, p2.image, mdat_merged)
    return ImageMetadataPair(im_merged, mdat_merged)


def merge_pair_vertically(p1: ImageMetadataPair, p2: ImageMetadataPair):
    """Merge two vertically adjacent pairs into a single pair.

    The metadata of both pairs is merged with ``merge_metadata_vertically``
    and the images are concatenated with ``concat_images_vertically`` using
    the merged metadata for the size of the result.

    Args:
        p1: First pair.
        p2: Second pair.

    Returns:
        A new ``ImageMetadataPair`` holding the concatenated image and the
        merged metadata.
    """
    mdat_merged = merge_metadata_vertically(p1.metadata, p2.metadata)
    # NOTE(review): assumes ImageMetadataPair exposes its image as `.image`
    # (it is constructed positionally as (image, metadata)) -- TODO confirm.
    im_merged = concat_images_vertically(p1.image, p2.image, mdat_merged)
    return ImageMetadataPair(im_merged, mdat_merged)
def merge_pair(p1, p2):
    """Check that two pairs are on the same page before they are merged.

    NOTE(review): only the page check is implemented so far; no merge is
    performed and nothing is returned.

    Args:
        p1: First pair; its metadata must contain ``Info.PAGE_IDX``.
        p2: Second pair; its metadata must contain ``Info.PAGE_IDX``.

    Raises:
        AssertionError: If the two pairs have different page indices.
    """
    page_idx_1 = p1.metadata[Info.PAGE_IDX]
    page_idx_2 = p2.metadata[Info.PAGE_IDX]
    assert page_idx_1 == page_idx_2, "Cannot merge pairs from different pages."
def concat_images_horizontally(im1: Image, im2: Image, metadata: dict):
    """Concatenate two images via ``concat_images`` with ``axis`` set to 0."""
    horizontal_axis = 0
    return concat_images(im1, im2, metadata, horizontal_axis)
def concat_images_vertically(im1: Image, im2: Image, metadata: dict):
    """Concatenate two images via ``concat_images`` with ``axis`` set to 1."""
    vertical_axis = 1
    return concat_images(im1, im2, metadata, vertical_axis)
def concat_images(im1: Image.Image, im2: Image.Image, metadata: dict, axis):
    """Paste two images next to each other onto a freshly created canvas.

    Args:
        im1: Image pasted at the origin of the canvas.
        im2: Image pasted directly after ``im1`` along ``axis``.
        metadata: Mapping providing ``Info.WIDTH`` and ``Info.HEIGHT``, which
            define the size of the canvas.
        axis: Index into the PIL ``size`` tuple ``(width, height)``: 0 places
            the images side by side, 1 stacks them on top of each other.

    Returns:
        A new image in the mode of ``im1`` containing both inputs.
    """
    canvas = Image.new(im1.mode, (metadata[Info.WIDTH], metadata[Info.HEIGHT]))
    # im2 starts where im1 ends along the concatenation axis.
    offset = im1.size[axis]
    # Paste boxes are (x, y): shift along y for axis 1, along x for axis 0.
    im2_origin = (0, offset) if axis else (offset, 0)
    canvas.paste(im1, box=(0, 0))
    canvas.paste(im2, box=im2_origin)
    return canvas
#####################################
def test_merge_metadata_horizontally(horizontal_merge_test_metadata):
    """Merging the two horizontal test metadata must equal the expected merged metadata."""
    left, right, expected = horizontal_merge_test_metadata
    merged = merge_metadata_horizontally(left, right)
    assert merged == expected
def test_merge_metadata_vertically(vertical_merge_test_metadata):
    """Merging the two vertical test metadata must equal the expected merged metadata."""
    upper, lower, expected = vertical_merge_test_metadata
    merged = merge_metadata_vertically(upper, lower)
    assert merged == expected
@pytest.fixture
def horizontal_merge_test_metadata(merge_test_metadata):
    """Provide two horizontally adjacent metadata and the metadata expected after merging.

    The second metadata is moved so that its ``X1`` equals the ``X2`` of the
    first; the expected metadata gets the summed width and the ``X2`` of the
    second.
    """
    left, right, expected = merge_test_metadata

    # Place the second box directly to the right of the first one.
    right[Info.X1] = left[Info.X2]
    right[Info.X2] = right[Info.X1] + right[Info.WIDTH]

    expected[Info.WIDTH] = left[Info.WIDTH] + right[Info.WIDTH]
    expected[Info.X2] = right[Info.X2]
    return left, right, expected
@pytest.fixture
def vertical_merge_test_metadata(merge_test_metadata):
    """Provide two vertically adjacent metadata and the metadata expected after merging.

    The second metadata is moved so that its ``Y1`` equals the ``Y2`` of the
    first; the expected metadata gets the summed height and the ``Y2`` of the
    second.
    """
    upper, lower, expected = merge_test_metadata

    # Place the second box directly below the first one.
    lower[Info.Y1] = upper[Info.Y2]
    lower[Info.Y2] = lower[Info.Y1] + lower[Info.HEIGHT]

    expected[Info.HEIGHT] = upper[Info.HEIGHT] + lower[Info.HEIGHT]
    expected[Info.Y2] = lower[Info.Y2]
    return upper, lower, expected
@pytest.fixture
def merge_test_metadata(base_patch_metadata):
    """Provide three independent deep copies of the base patch metadata.

    A tuple is returned so the fixture value can be unpacked or inspected more
    than once; a lazily evaluated iterator would be exhausted after the first
    unpacking.

    Returns:
        A 3-tuple of deep copies of ``base_patch_metadata``.
    """
    return tuple(deepcopy(base_patch_metadata) for _ in range(3))
@pytest.fixture
def merge_test_image_metadata_pairs():
    """Placeholder fixture without an implementation; it currently provides ``None``."""
    return None
def test_concat_images_horizontally(horizontal_merge_test_metadata):
    """The horizontally concatenated image must have the size of the expected merged image."""
    mdat1, mdat2, mdat_merged = horizontal_merge_test_metadata
    im1 = random_single_color_image_from_metadata(mdat1)
    im2 = random_single_color_image_from_metadata(mdat2)
    im_expected = random_single_color_image_from_metadata(mdat_merged)

    im_concatenated = concat_images_horizontally(im1, im2, mdat_merged)

    # Only the size is compared here, not the pixel content.
    assert im_concatenated.size == im_expected.size
def test_concat_images_vertically(vertical_merge_test_metadata):
    """The vertically concatenated image must have the size of the expected merged image."""
    mdat1, mdat2, mdat_merged = vertical_merge_test_metadata
    im1 = random_single_color_image_from_metadata(mdat1)
    im2 = random_single_color_image_from_metadata(mdat2)
    im_expected = random_single_color_image_from_metadata(mdat_merged)

    im_concatenated = concat_images_vertically(im1, im2, mdat_merged)

    # Only the size is compared here, not the pixel content.
    assert im_concatenated.size == im_expected.size
class Stitcher:
    """Groups image/metadata pairs by their coordinates in preparation for merging them.

    NOTE(review): ``stitch`` looks unfinished -- it only builds a lazy pipeline
    that is never consumed and it does not return anything.
    """

    @staticmethod
    def groupby(pairs, coord):
        """Group ``pairs`` by the value of the coordinate named ``coord``.

        Args:
            pairs: Iterable of pairs; it is sorted by the coordinate first.
            coord: Coordinate name accepted by ``make_coord_getter``.

        Returns:
            A lazy iterator yielding one list of pairs per distinct coordinate value.
        """
        key = make_coord_getter(coord)
        ordered = sorted(pairs, key=key)
        return map(lambda keyed_group: list(keyed_group[1]), groupby(ordered, key))

    def stitch(self, pairs: Iterable[ImageMetadataPair]) -> ImageMetadataPair:
        """Set up grouping by ``x1``, then ``x2``, sorting by ``y1`` and merging each group.

        Only the first grouping step (sorting ``pairs`` by ``x1``) runs eagerly;
        every later step is lazy.
        """
        by_x1 = self.groupby(pairs, "x1")
        by_x1_then_x2 = chain.from_iterable(self.groupby(group, "x2") for group in by_x1)
        top_to_bottom = (sorted(group, key=y1_getter) for group in by_x1_then_x2)
        # NOTE(review): the merged groups are never consumed or returned.
        merged_groups = map(merge_group, top_to_bottom)
@pytest.mark.parametrize("width", [160])
@pytest.mark.parametrize("height", [90])
@pytest.mark.parametrize("page_width", [int(160 * 1.1)])
@pytest.mark.parametrize("page_height", [int(90 * 1.1)])
@pytest.mark.skip(reason="Stitcher.stitch does not return a stitched pair yet")
def test_image_stitcher(patch_image_metadata_pairs, base_patch_metadata):
    """Stitching all patches back together must reproduce the base patch metadata.

    The pairs are requested as a fixture argument so that the test receives the
    fixture value rather than the module-level fixture function.
    """
    stitched = Stitcher().stitch(patch_image_metadata_pairs)
    assert stitched.metadata == base_patch_metadata
@pytest.mark.parametrize("width", [160])
@pytest.mark.parametrize("height", [90])
@pytest.mark.parametrize("page_width", [int(160 * 1.1)])
@pytest.mark.parametrize("page_height", [int(90 * 1.1)])
def test_partial_image_metadata_pairs(patch_image_metadata_pairs, page_width, page_height, tmp_path):
    """Add every patch pair to a PDF and write it to disk.

    The PDF goes into pytest's per-test temporary directory instead of a fixed
    path, so the test does not depend on ``/tmp`` existing and concurrent runs
    do not overwrite each other's output.
    """
    pdf = fpdf.FPDF(unit="pt", format=(page_width, page_height))
    for pair in patch_image_metadata_pairs:
        add_image(pdf, pair)

    pdf_path = tmp_path / "patches.pdf"
    # Pass a plain string; not every fpdf version is known to accept path objects.
    pdf.output(str(pdf_path))
    assert pdf_path.exists()
@pytest.fixture
def patch_image_metadata_pairs(patches_metadata) -> List[ImageMetadataPair]:
    """Provide one ``ImageMetadataPair`` per patch metadata.

    Each pair combines an image created by
    ``random_single_color_image_from_metadata`` with the metadata it was
    created from.
    """
    return [
        ImageMetadataPair(random_single_color_image_from_metadata(metadata), metadata)
        for metadata in patches_metadata
    ]
@pytest.fixture
def base_patch_metadata(width, height, page_width, page_height):
    """Provide the base position metadata with the box anchored at the origin.

    The result of ``get_base_position_metadata`` is merged with coordinates
    ``X1 = 0``, ``Y1 = 0``, ``X2 = width`` and ``Y2 = height``.
    """
    position_metadata = get_base_position_metadata(width, height, page_width, page_height)
    origin_anchored_box = {Info.X1: 0, Info.Y1: 0, Info.X2: width, Info.Y2: height}
    return merge(position_metadata, origin_anchored_box)
@pytest.fixture
def patches_metadata(base_patch_metadata):
    """Provide the list of metadata produced by ``BoxSplitter().split_box`` for the base patch."""
    splitter = BoxSplitter()
    return [*splitter.split_box(base_patch_metadata)]