added img-mdat-pair merging logic

This commit is contained in:
Matthias Bisping 2022-04-07 16:11:12 +02:00
parent 5e8b55ef10
commit 2b1e7cbb08
2 changed files with 61 additions and 22 deletions

View File

@ -485,3 +485,8 @@ def random_single_color_image_from_metadata(metadata):
"RGB", (metadata[Info.WIDTH], metadata[Info.HEIGHT]), color=tuple(map(int, np.random.uniform(size=3) * 255)) "RGB", (metadata[Info.WIDTH], metadata[Info.HEIGHT]), color=tuple(map(int, np.random.uniform(size=3) * 255))
) )
return image return image
def random_size_gray_image_from_metadata(metadata):
image = Image.new("RGB", (metadata[Info.WIDTH], metadata[Info.HEIGHT]), color=(100, 100, 100))
return image

View File

@ -1,19 +1,25 @@
from copy import deepcopy from copy import deepcopy
from functools import partial from functools import partial
from itertools import groupby
from itertools import starmap, chain, repeat from itertools import starmap, chain, repeat
from operator import itemgetter, attrgetter
from typing import Iterable, List from typing import Iterable, List
import fpdf import fpdf
import numpy as np
import pytest import pytest
from PIL import Image from PIL import Image
from funcy import merge, second, compose, rpartial, curry, juxt, project, omit, flip, identity from funcy import merge, second, compose, rpartial, juxt
from image_prediction.estimator.preprocessor.utils import image_to_normalized_tensor
from image_prediction.image_extractor.extractor import ImageMetadataPair from image_prediction.image_extractor.extractor import ImageMetadataPair
from image_prediction.info import Info from image_prediction.info import Info
from test.conftest import get_base_position_metadata, add_image, random_single_color_image_from_metadata from test.conftest import (
get_base_position_metadata,
add_image,
random_single_color_image_from_metadata,
random_size_gray_image_from_metadata,
)
from test.utils.stitching import BoxSplitter, VerticalKeyMapper, HorizontalKeyMapper from test.utils.stitching import BoxSplitter, VerticalKeyMapper, HorizontalKeyMapper
from itertools import groupby
def make_getter(key): def make_getter(key):
@ -77,19 +83,19 @@ def merge_metadata(m1, m2):
def merge_pair_horizontally(p1: ImageMetadataPair, p2: ImageMetadataPair): def merge_pair_horizontally(p1: ImageMetadataPair, p2: ImageMetadataPair):
mdat_merged = merge_metadata_horizontally(p1.metadata, p2.metadata) metadata_merged = merge_metadata_horizontally(p1.metadata, p2.metadata)
image_concatenated = concat_images_horizontally(p1.image, p2.image, metadata_merged)
return ImageMetadataPair(image_concatenated, metadata_merged)
def merge_pair_vertically(p1: ImageMetadataPair, p2: ImageMetadataPair): def merge_pair_vertically(p1: ImageMetadataPair, p2: ImageMetadataPair):
mdat_merged = merge_metadata_vertically(p1.metadata, p2.metadata) metadata_merged = merge_metadata_vertically(p1.metadata, p2.metadata)
image_concatenated = concat_images_vertically(p1.image, p2.image, metadata_merged)
return ImageMetadataPair(image_concatenated, metadata_merged)
def merge_pair_horizontally(p1: ImageMetadataPair, p2: ImageMetadataPair): # def merge_pair(p1, p2):
pass # assert p1.metadata[Info.PAGE_IDX] == p2.metadta[Info.PAGE_IDX]
def merge_pair(p1, p2):
assert p1.metadata[Info.PAGE_IDX] == p2.metadta[Info.PAGE_IDX]
def concat_images_horizontally(im1: Image, im2: Image, metadata: dict): def concat_images_horizontally(im1: Image, im2: Image, metadata: dict):
@ -109,7 +115,7 @@ def concat_images(im1: Image, im2: Image, metadata: dict, axis):
offsets = [0, *[im.size[axis] for im in images]] offsets = [0, *[im.size[axis] for im in images]]
for im, offset in zip(images, offsets): for im, offset in zip(images, offsets):
box = (offset, 0) if axis else (0, offset) box = (offset, 0) if not axis else (0, offset)
im_aggr.paste(im, box=box) im_aggr.paste(im, box=box)
return im_aggr return im_aggr
@ -118,6 +124,36 @@ def concat_images(im1: Image, im2: Image, metadata: dict, axis):
##################################### #####################################
def test_merge_pairs_horizontally(horizontal_merge_test_pairs):
pr1, pr2, pr_merged_expected = horizontal_merge_test_pairs
pr_merged = merge_pair_horizontally(pr1, pr2)
assert pr_merged.metadata == pr_merged_expected.metadata
images_equal(pr_merged.image, pr_merged_expected.image)
def test_merge_pairs_vertically(vertical_merge_test_pairs):
pr1, pr2, pr_merged_expected = vertical_merge_test_pairs
pr_merged = merge_pair_vertically(pr1, pr2)
assert pr_merged.metadata == pr_merged_expected.metadata
images_equal(pr_merged.image, pr_merged_expected.image)
def images_equal(im1: Image, im2: Image):
assert np.allclose(image_to_normalized_tensor(im1), image_to_normalized_tensor(im2))
@pytest.fixture
def horizontal_merge_test_pairs(horizontal_merge_test_metadata):
images = map(random_size_gray_image_from_metadata, horizontal_merge_test_metadata)
return list(starmap(ImageMetadataPair, zip(images, horizontal_merge_test_metadata)))
@pytest.fixture
def vertical_merge_test_pairs(vertical_merge_test_metadata):
images = map(random_size_gray_image_from_metadata, vertical_merge_test_metadata)
return list(starmap(ImageMetadataPair, zip(images, vertical_merge_test_metadata)))
def test_merge_metadata_horizontally(horizontal_merge_test_metadata): def test_merge_metadata_horizontally(horizontal_merge_test_metadata):
mdat1, mdat2, mdat_merged = horizontal_merge_test_metadata mdat1, mdat2, mdat_merged = horizontal_merge_test_metadata
assert merge_metadata_horizontally(mdat1, mdat2) == mdat_merged assert merge_metadata_horizontally(mdat1, mdat2) == mdat_merged
@ -159,21 +195,19 @@ def merge_test_metadata(base_patch_metadata):
return juxt(*repeat(deepcopy, 3))(base_patch_metadata) return juxt(*repeat(deepcopy, 3))(base_patch_metadata)
@pytest.fixture
def merge_test_image_metadata_pairs():
pass
def test_concat_images_horizontally(horizontal_merge_test_metadata): def test_concat_images_horizontally(horizontal_merge_test_metadata):
mdat1, mdat2, mdat_merged = horizontal_merge_test_metadata mdat1, mdat2, mdat_merged = horizontal_merge_test_metadata
im1, im2, im_merged = map(random_single_color_image_from_metadata, [mdat1, mdat2, mdat_merged]) im1, im2, im_merged_expected = map(random_size_gray_image_from_metadata, [mdat1, mdat2, mdat_merged])
assert concat_images_horizontally(im1, im2, mdat_merged).size == im_merged.size im_merged = concat_images_horizontally(im1, im2, mdat_merged)
assert im_merged.size == im_merged_expected.size
images_equal(im_merged, im_merged_expected)
def test_concat_images_vertically(vertical_merge_test_metadata): def test_concat_images_vertically(vertical_merge_test_metadata):
mdat1, mdat2, mdat_merged = vertical_merge_test_metadata mdat1, mdat2, mdat_merged = vertical_merge_test_metadata
im1, im2, im_merged = map(random_single_color_image_from_metadata, [mdat1, mdat2, mdat_merged]) im1, im2, im_merged_expected = map(random_single_color_image_from_metadata, [mdat1, mdat2, mdat_merged])
assert concat_images_vertically(im1, im2, mdat_merged).size == im_merged.size im_merged = concat_images_vertically(im1, im2, mdat_merged)
assert im_merged.size == im_merged_expected.size
class Stitcher: class Stitcher: