added image concatenation; refactoring

This commit is contained in:
Matthias Bisping 2022-04-07 11:42:38 +02:00
parent 3266e0af58
commit 5e8b55ef10

View File

@ -1,12 +1,13 @@
from copy import deepcopy
from functools import partial
from itertools import starmap, chain, repeat
from operator import itemgetter
from operator import itemgetter, attrgetter
from typing import Iterable, List
import fpdf
import pytest
from funcy import merge, second, compose, rpartial, curry, juxt, project, omit
from PIL import Image
from funcy import merge, second, compose, rpartial, curry, juxt, project, omit, flip, identity
from image_prediction.image_extractor.extractor import ImageMetadataPair
from image_prediction.info import Info
@ -42,13 +43,13 @@ x1_getter, y1_getter, x2_getter, y2_getter = map(make_coord_getter, ("x1", "y1",
width_getter, height_getter = map(make_length_getter, ("width", "height"))
# def merge_group(group, axis="y"):
#
# group = list(group)
# current_pair = group.pop(0)
# for pair in group:
# if y2_getter(current_pair) == y1_getter(pair):
# current_box = merge_pair(current_pair, pair)
def merge_group(group, axis="y"):
group = list(group)
current_pair = group.pop(0)
for pair in group:
if y2_getter(current_pair) == y1_getter(pair):
current_box = merge_pair_vertically(current_pair, pair)
def merge_metadata_horizontally(m1, m2):
@ -79,12 +80,57 @@ def merge_pair_horizontally(p1: ImageMetadataPair, p2: ImageMetadataPair):
mdat_merged = merge_metadata_horizontally(p1.metadata, p2.metadata)
# def merge_pair(p1, p2):
#
# assert p1.metadata[Info.PAGE_IDX] == p2.metadta[Info.PAGE_IDX]
def merge_pair_vertically(p1: ImageMetadataPair, p2: ImageMetadataPair):
mdat_merged = merge_metadata_vertically(p1.metadata, p2.metadata)
def test_merge_metadata_horizontally(merge_test_metadata):
def merge_pair_horizontally(p1: ImageMetadataPair, p2: ImageMetadataPair):
pass
def merge_pair(p1, p2):
assert p1.metadata[Info.PAGE_IDX] == p2.metadta[Info.PAGE_IDX]
def concat_images_horizontally(im1: Image, im2: Image, metadata: dict):
return concat_images(im1, im2, metadata, 0)
def concat_images_vertically(im1: Image, im2: Image, metadata: dict):
return concat_images(im1, im2, metadata, 1)
def concat_images(im1: Image, im2: Image, metadata: dict, axis):
im_aggr = Image.new(im1.mode, (metadata[Info.WIDTH], metadata[Info.HEIGHT]))
images = [im1, im2]
offsets = [0, *[im.size[axis] for im in images]]
for im, offset in zip(images, offsets):
box = (offset, 0) if axis else (0, offset)
im_aggr.paste(im, box=box)
return im_aggr
#####################################
def test_merge_metadata_horizontally(horizontal_merge_test_metadata):
mdat1, mdat2, mdat_merged = horizontal_merge_test_metadata
assert merge_metadata_horizontally(mdat1, mdat2) == mdat_merged
def test_merge_metadata_vertically(vertical_merge_test_metadata):
mdat1, mdat2, mdat_merged = vertical_merge_test_metadata
assert merge_metadata_vertically(mdat1, mdat2) == mdat_merged
@pytest.fixture
def horizontal_merge_test_metadata(merge_test_metadata):
mdat1, mdat2, mdat_merged = merge_test_metadata
mdat2[Info.X1] = mdat1[Info.X2]
@ -92,10 +138,12 @@ def test_merge_metadata_horizontally(merge_test_metadata):
mdat_merged.update({Info.WIDTH: mdat1[Info.WIDTH] + mdat2[Info.WIDTH], Info.X2: mdat2[Info.X2]})
assert merge_metadata_horizontally(mdat1, mdat2) == mdat_merged
return mdat1, mdat2, mdat_merged
def test_merge_metadata_vertically(merge_test_metadata):
@pytest.fixture
def vertical_merge_test_metadata(merge_test_metadata):
mdat1, mdat2, mdat_merged = merge_test_metadata
mdat2[Info.Y1] = mdat1[Info.Y2]
@ -103,27 +151,43 @@ def test_merge_metadata_vertically(merge_test_metadata):
mdat_merged.update({Info.HEIGHT: mdat1[Info.HEIGHT] + mdat2[Info.HEIGHT], Info.Y2: mdat2[Info.Y2]})
assert merge_metadata_vertically(mdat1, mdat2) == mdat_merged
return mdat1, mdat2, mdat_merged
@pytest.fixture
def merge_test_metadata(base_patch_metadata):
return juxt(*repeat(deepcopy, 3))(base_patch_metadata)
return juxt(*repeat(deepcopy, 3))(base_patch_metadata)
@pytest.fixture
def merge_test_image_metadata_pairs():
pass
# class Stitcher:
# @staticmethod
# def groupby(pairs, coord):
# coord_getter = make_coord_getter(coord)
# pairs = sorted(pairs, key=coord_getter)
# return map(compose(list, second), groupby(pairs, coord_getter))
#
# def stitch(self, pairs: Iterable[ImageMetadataPair]) -> ImageMetadataPair:
# groups = self.groupby(pairs, "x1")
# groups = chain.from_iterable(map(rpartial(self.groupby, "x2"), groups))
# groups = map(partial(sorted, key=y1_getter), groups)
# groups = map(merge_group, groups)
def test_concat_images_horizontally(horizontal_merge_test_metadata):
mdat1, mdat2, mdat_merged = horizontal_merge_test_metadata
im1, im2, im_merged = map(random_single_color_image_from_metadata, [mdat1, mdat2, mdat_merged])
assert concat_images_horizontally(im1, im2, mdat_merged).size == im_merged.size
def test_concat_images_vertically(vertical_merge_test_metadata):
mdat1, mdat2, mdat_merged = vertical_merge_test_metadata
im1, im2, im_merged = map(random_single_color_image_from_metadata, [mdat1, mdat2, mdat_merged])
assert concat_images_vertically(im1, im2, mdat_merged).size == im_merged.size
class Stitcher:
@staticmethod
def groupby(pairs, coord):
coord_getter = make_coord_getter(coord)
pairs = sorted(pairs, key=coord_getter)
return map(compose(list, second), groupby(pairs, coord_getter))
def stitch(self, pairs: Iterable[ImageMetadataPair]) -> ImageMetadataPair:
groups = self.groupby(pairs, "x1")
groups = chain.from_iterable(map(rpartial(self.groupby, "x2"), groups))
groups = map(partial(sorted, key=y1_getter), groups)
groups = map(merge_group, groups)
@pytest.mark.parametrize("width", [160])