added image concatenation; refactoring

This commit is contained in:
Matthias Bisping 2022-04-07 11:42:38 +02:00
parent 3266e0af58
commit 5e8b55ef10

View File

@ -1,12 +1,13 @@
from copy import deepcopy from copy import deepcopy
from functools import partial from functools import partial
from itertools import starmap, chain, repeat from itertools import starmap, chain, repeat
from operator import itemgetter from operator import itemgetter, attrgetter
from typing import Iterable, List from typing import Iterable, List
import fpdf import fpdf
import pytest import pytest
from funcy import merge, second, compose, rpartial, curry, juxt, project, omit from PIL import Image
from funcy import merge, second, compose, rpartial, curry, juxt, project, omit, flip, identity
from image_prediction.image_extractor.extractor import ImageMetadataPair from image_prediction.image_extractor.extractor import ImageMetadataPair
from image_prediction.info import Info from image_prediction.info import Info
@ -42,13 +43,13 @@ x1_getter, y1_getter, x2_getter, y2_getter = map(make_coord_getter, ("x1", "y1",
width_getter, height_getter = map(make_length_getter, ("width", "height")) width_getter, height_getter = map(make_length_getter, ("width", "height"))
# def merge_group(group, axis="y"): def merge_group(group, axis="y"):
#
# group = list(group) group = list(group)
# current_pair = group.pop(0) current_pair = group.pop(0)
# for pair in group: for pair in group:
# if y2_getter(current_pair) == y1_getter(pair): if y2_getter(current_pair) == y1_getter(pair):
# current_box = merge_pair(current_pair, pair) current_box = merge_pair_vertically(current_pair, pair)
def merge_metadata_horizontally(m1, m2): def merge_metadata_horizontally(m1, m2):
@ -79,12 +80,57 @@ def merge_pair_horizontally(p1: ImageMetadataPair, p2: ImageMetadataPair):
mdat_merged = merge_metadata_horizontally(p1.metadata, p2.metadata) mdat_merged = merge_metadata_horizontally(p1.metadata, p2.metadata)
# def merge_pair(p1, p2): def merge_pair_vertically(p1: ImageMetadataPair, p2: ImageMetadataPair):
# mdat_merged = merge_metadata_vertically(p1.metadata, p2.metadata)
# assert p1.metadata[Info.PAGE_IDX] == p2.metadta[Info.PAGE_IDX]
def test_merge_metadata_horizontally(merge_test_metadata): def merge_pair_horizontally(p1: ImageMetadataPair, p2: ImageMetadataPair):
pass
def merge_pair(p1, p2):
assert p1.metadata[Info.PAGE_IDX] == p2.metadta[Info.PAGE_IDX]
def concat_images_horizontally(im1: Image, im2: Image, metadata: dict):
return concat_images(im1, im2, metadata, 0)
def concat_images_vertically(im1: Image, im2: Image, metadata: dict):
return concat_images(im1, im2, metadata, 1)
def concat_images(im1: Image, im2: Image, metadata: dict, axis):
im_aggr = Image.new(im1.mode, (metadata[Info.WIDTH], metadata[Info.HEIGHT]))
images = [im1, im2]
offsets = [0, *[im.size[axis] for im in images]]
for im, offset in zip(images, offsets):
box = (offset, 0) if axis else (0, offset)
im_aggr.paste(im, box=box)
return im_aggr
#####################################
def test_merge_metadata_horizontally(horizontal_merge_test_metadata):
mdat1, mdat2, mdat_merged = horizontal_merge_test_metadata
assert merge_metadata_horizontally(mdat1, mdat2) == mdat_merged
def test_merge_metadata_vertically(vertical_merge_test_metadata):
mdat1, mdat2, mdat_merged = vertical_merge_test_metadata
assert merge_metadata_vertically(mdat1, mdat2) == mdat_merged
@pytest.fixture
def horizontal_merge_test_metadata(merge_test_metadata):
mdat1, mdat2, mdat_merged = merge_test_metadata mdat1, mdat2, mdat_merged = merge_test_metadata
mdat2[Info.X1] = mdat1[Info.X2] mdat2[Info.X1] = mdat1[Info.X2]
@ -92,10 +138,12 @@ def test_merge_metadata_horizontally(merge_test_metadata):
mdat_merged.update({Info.WIDTH: mdat1[Info.WIDTH] + mdat2[Info.WIDTH], Info.X2: mdat2[Info.X2]}) mdat_merged.update({Info.WIDTH: mdat1[Info.WIDTH] + mdat2[Info.WIDTH], Info.X2: mdat2[Info.X2]})
assert merge_metadata_horizontally(mdat1, mdat2) == mdat_merged return mdat1, mdat2, mdat_merged
def test_merge_metadata_vertically(merge_test_metadata): @pytest.fixture
def vertical_merge_test_metadata(merge_test_metadata):
mdat1, mdat2, mdat_merged = merge_test_metadata mdat1, mdat2, mdat_merged = merge_test_metadata
mdat2[Info.Y1] = mdat1[Info.Y2] mdat2[Info.Y1] = mdat1[Info.Y2]
@ -103,27 +151,43 @@ def test_merge_metadata_vertically(merge_test_metadata):
mdat_merged.update({Info.HEIGHT: mdat1[Info.HEIGHT] + mdat2[Info.HEIGHT], Info.Y2: mdat2[Info.Y2]}) mdat_merged.update({Info.HEIGHT: mdat1[Info.HEIGHT] + mdat2[Info.HEIGHT], Info.Y2: mdat2[Info.Y2]})
assert merge_metadata_vertically(mdat1, mdat2) == mdat_merged return mdat1, mdat2, mdat_merged
@pytest.fixture @pytest.fixture
def merge_test_metadata(base_patch_metadata): def merge_test_metadata(base_patch_metadata):
return juxt(*repeat(deepcopy, 3))(base_patch_metadata) return juxt(*repeat(deepcopy, 3))(base_patch_metadata)
@pytest.fixture
def merge_test_image_metadata_pairs():
pass
# class Stitcher:
# @staticmethod def test_concat_images_horizontally(horizontal_merge_test_metadata):
# def groupby(pairs, coord): mdat1, mdat2, mdat_merged = horizontal_merge_test_metadata
# coord_getter = make_coord_getter(coord) im1, im2, im_merged = map(random_single_color_image_from_metadata, [mdat1, mdat2, mdat_merged])
# pairs = sorted(pairs, key=coord_getter) assert concat_images_horizontally(im1, im2, mdat_merged).size == im_merged.size
# return map(compose(list, second), groupby(pairs, coord_getter))
#
# def stitch(self, pairs: Iterable[ImageMetadataPair]) -> ImageMetadataPair: def test_concat_images_vertically(vertical_merge_test_metadata):
# groups = self.groupby(pairs, "x1") mdat1, mdat2, mdat_merged = vertical_merge_test_metadata
# groups = chain.from_iterable(map(rpartial(self.groupby, "x2"), groups)) im1, im2, im_merged = map(random_single_color_image_from_metadata, [mdat1, mdat2, mdat_merged])
# groups = map(partial(sorted, key=y1_getter), groups) assert concat_images_vertically(im1, im2, mdat_merged).size == im_merged.size
# groups = map(merge_group, groups)
class Stitcher:
@staticmethod
def groupby(pairs, coord):
coord_getter = make_coord_getter(coord)
pairs = sorted(pairs, key=coord_getter)
return map(compose(list, second), groupby(pairs, coord_getter))
def stitch(self, pairs: Iterable[ImageMetadataPair]) -> ImageMetadataPair:
groups = self.groupby(pairs, "x1")
groups = chain.from_iterable(map(rpartial(self.groupby, "x2"), groups))
groups = map(partial(sorted, key=y1_getter), groups)
groups = map(merge_group, groups)
@pytest.mark.parametrize("width", [160]) @pytest.mark.parametrize("width", [160])