stitcher test passes
This commit is contained in:
parent
160973e2be
commit
8ac9fcb19f
@ -6,6 +6,7 @@ from typing import Iterable, List
|
||||
|
||||
import fpdf
|
||||
import numpy as np
|
||||
import pdf2image
|
||||
import pytest
|
||||
from PIL import Image
|
||||
from funcy import merge, second, compose, rpartial, juxt, rest, first, one, iterate
|
||||
@ -117,24 +118,11 @@ def concat_images(im1: Image, im2: Image, metadata: dict, axis):
|
||||
return im_aggr
|
||||
|
||||
|
||||
class Stitcher:
|
||||
@staticmethod
|
||||
def groupby(pairs, coord):
|
||||
coord_getter = make_coord_getter(coord)
|
||||
pairs = sorted(pairs, key=coord_getter)
|
||||
return map(compose(list, second), groupby(pairs, coord_getter))
|
||||
|
||||
def stitch(self, pairs: Iterable[ImageMetadataPair]) -> ImageMetadataPair:
|
||||
groups = self.groupby(pairs, "x1")
|
||||
groups = chain.from_iterable(map(rpartial(self.groupby, "x2"), groups))
|
||||
groups = map(partial(sorted, key=y1_getter), groups)
|
||||
groups = map(merge_group, groups)
|
||||
|
||||
|
||||
def make_merger_aggregator(direction):
|
||||
"""Produces a function f : [H, T1, ... Tn] -> [HTi...Tj, Tk ... Tl] that merges adjacent image-metadata pairs on the
|
||||
head H and aggregates non-adjacent in the tail T.
|
||||
"""
|
||||
|
||||
def merger_aggregator(pairs):
|
||||
def merge_on_head_and_aggregate_in_tail(pairs_aggr: Iterable[ImageMetadataPair], pair: ImageMetadataPair):
|
||||
"""Keeps the image that is being merged with as the head and aggregates non-mergables in the tail."""
|
||||
@ -175,9 +163,50 @@ def merge_group_vertically(group):
|
||||
return merge_group(group, "y")
|
||||
|
||||
|
||||
class Stitcher:
|
||||
@staticmethod
|
||||
def groupby(pairs, coord):
|
||||
coord_getter = make_coord_getter(coord)
|
||||
pairs = sorted(pairs, key=coord_getter)
|
||||
return map(compose(list, second), groupby(pairs, coord_getter))
|
||||
|
||||
def stitch(self, pairs: Iterable[ImageMetadataPair]) -> List[ImageMetadataPair]:
|
||||
|
||||
pairs = list(pairs)
|
||||
n = len(pairs)
|
||||
|
||||
while True:
|
||||
|
||||
groups = self.groupby(pairs, "x1")
|
||||
groups = chain.from_iterable(map(rpartial(self.groupby, "x2"), groups))
|
||||
groups = map(merge_group_vertically, groups)
|
||||
pairs = chain.from_iterable(groups)
|
||||
|
||||
groups = self.groupby(pairs, "y1")
|
||||
groups = chain.from_iterable(map(rpartial(self.groupby, "y2"), groups))
|
||||
groups = map(merge_group_horizontally, groups)
|
||||
pairs = list(chain.from_iterable(groups))
|
||||
|
||||
if len(pairs) == n:
|
||||
return pairs
|
||||
n = len(pairs)
|
||||
|
||||
|
||||
#####################################
|
||||
|
||||
|
||||
# @pytest.mark.parametrize("width", [160])
|
||||
# @pytest.mark.parametrize("height", [90])
|
||||
# @pytest.mark.parametrize("page_width", [int(160 * 1.1)])
|
||||
# @pytest.mark.parametrize("page_height", [int(90 * 1.1)])
|
||||
def test_image_stitcher(patch_image_metadata_pairs, base_patch_metadata, base_patch_image):
|
||||
pair_stitched = Stitcher().stitch(patch_image_metadata_pairs)[0]
|
||||
assert pair_stitched.metadata == base_patch_metadata
|
||||
# pair_stitched.image.show()
|
||||
# base_patch_image.show()
|
||||
assert images_equal(pair_stitched.image.resize((10, 10)), base_patch_image.resize((10, 10)), atol=0.4)
|
||||
|
||||
|
||||
def test_merge_group_horizontally(horizontal_merge_test_pairs):
|
||||
pr1, pr2, pr_merged_expected = horizontal_merge_test_pairs
|
||||
|
||||
@ -230,8 +259,8 @@ def test_merge_pairs_vertically(vertical_merge_test_pairs):
|
||||
assert pair_equal(pr_merged, pr_merged_expected)
|
||||
|
||||
|
||||
def images_equal(im1: Image, im2: Image):
|
||||
return np.allclose(image_to_normalized_tensor(im1), image_to_normalized_tensor(im2))
|
||||
def images_equal(im1: Image, im2: Image, **kwargs):
|
||||
return np.allclose(image_to_normalized_tensor(im1), image_to_normalized_tensor(im2), **kwargs)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@ -285,6 +314,11 @@ def merge_test_metadata(base_patch_metadata):
|
||||
return juxt(*repeat(deepcopy, 3))(base_patch_metadata)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def base_patch_image(stitch_test_pdf):
|
||||
return pdf2image.convert_from_bytes(stitch_test_pdf)[0]
|
||||
|
||||
|
||||
def test_concat_images_horizontally(horizontal_merge_test_metadata):
|
||||
mdat1, mdat2, mdat_merged = horizontal_merge_test_metadata
|
||||
im1, im2, im_merged_expected = map(random_size_gray_image_from_metadata, [mdat1, mdat2, mdat_merged])
|
||||
@ -301,28 +335,15 @@ def test_concat_images_vertically(vertical_merge_test_metadata):
|
||||
assert images_equal(im_merged, im_merged_expected)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("width", [160])
|
||||
@pytest.mark.parametrize("height", [90])
|
||||
@pytest.mark.parametrize("page_width", [int(160 * 1.1)])
|
||||
@pytest.mark.parametrize("page_height", [int(90 * 1.1)])
|
||||
@pytest.mark.skip()
|
||||
def test_image_stitcher(patches_metadata, base_patch_metadata):
|
||||
# noinspection PyTypeChecker
|
||||
assert Stitcher().stitch(patch_image_metadata_pairs).metadata == base_patch_metadata
|
||||
@pytest.fixture
|
||||
def stitch_test_pdf(patch_image_metadata_pairs, width, height):
|
||||
|
||||
|
||||
@pytest.mark.parametrize("width", [160])
|
||||
@pytest.mark.parametrize("height", [90])
|
||||
@pytest.mark.parametrize("page_width", [int(160 * 1.1)])
|
||||
@pytest.mark.parametrize("page_height", [int(90 * 1.1)])
|
||||
def test_partial_image_metadata_pairs(patch_image_metadata_pairs, page_width, page_height):
|
||||
|
||||
pdf = fpdf.FPDF(unit="pt", format=(page_width, page_height))
|
||||
pdf = fpdf.FPDF(unit="pt", format=(width, height))
|
||||
|
||||
for pair in patch_image_metadata_pairs:
|
||||
add_image(pdf, pair)
|
||||
|
||||
pdf.output("/tmp/bla.pdf")
|
||||
return pdf.output(dest="S").encode("latin1")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user