image-classification-service/test/unit_tests/image_stitching_test.py
Matthias Bisping 3b18fc6158 refactoring
2022-04-08 13:56:57 +02:00

198 lines
7.0 KiB
Python

from copy import deepcopy
from functools import partial
from itertools import starmap, repeat
from typing import List
import fpdf
import numpy as np
import pdf2image
import pytest
from PIL import Image
from funcy import merge, juxt, one
from image_prediction.estimator.preprocessor.utils import image_to_normalized_tensor
from image_prediction.image_extractor.extractor import ImageMetadataPair
from image_prediction.info import Info
from image_prediction.stitching.stitching import stitch_pairs
from image_prediction.stitching.utils import (
make_coord_getter,
make_length_getter,
)
from image_prediction.stitching.merging import merge_metadata_horizontally, merge_metadata_vertically, \
merge_pair_horizontally, merge_pair_vertically, concat_images_horizontally, concat_images_vertically, \
merge_group_horizontally, merge_group_vertically
from test.conftest import (
get_base_position_metadata,
add_image,
random_single_color_image_from_metadata,
random_size_gray_image_from_metadata,
)
from test.utils.stitching import BoxSplitter
# Module-level accessors built by the stitching utility factories: one per
# coordinate key and one per side-length key.
x1_getter, y1_getter, x2_getter, y2_getter = (
    make_coord_getter(coord_name) for coord_name in ("x1", "y1", "x2", "y2")
)
width_getter, height_getter = (
    make_length_getter(length_name) for length_name in ("width", "height")
)
def test_image_stitcher(patch_image_metadata_pairs, base_patch_metadata, base_patch_image):
    """Stitching the patch pairs must reproduce the base patch.

    The first stitched pair is expected to carry the base metadata, and its
    image is compared against the reference image after both have been
    resized to 10x10, using an absolute tolerance of 0.4.
    """
    stitched_pairs = stitch_pairs(patch_image_metadata_pairs)
    stitched = stitched_pairs[0]

    assert stitched.metadata == base_patch_metadata

    # Compare coarse thumbnails rather than the full-resolution images.
    thumbnail_size = (10, 10)
    actual_thumbnail = stitched.image.resize(thumbnail_size)
    expected_thumbnail = base_patch_image.resize(thumbnail_size)
    assert images_equal(actual_thumbnail, expected_thumbnail, atol=0.4)
def test_merge_group_horizontally(horizontal_merge_test_pairs):
    """Check horizontal group merging with two and with three input pairs.

    Two horizontally adjacent pairs must collapse into the single expected
    pair. With an additional pair whose height and y2 are 30 larger than
    the second pair's, two results are expected, exactly one of which equals
    the expected merged pair.
    """
    first, second, expected = horizontal_merge_test_pairs

    merged_group = merge_group_horizontally([first, second])
    assert len(merged_group) == 1
    assert pair_equal(merged_group[0], expected)

    # Derive a third pair from the second one, but 30 units taller.
    taller_metadata = deepcopy(second.metadata)
    taller_metadata[Info.HEIGHT] += 30
    taller_metadata[Info.Y2] += 30
    taller_image = random_size_gray_image_from_metadata(taller_metadata)
    taller = ImageMetadataPair(taller_image, taller_metadata)

    merged_group = merge_group_horizontally([first, second, taller])
    assert len(merged_group) == 2
    assert one(lambda candidate: pair_equal(expected, candidate), merged_group)
def test_merge_group_vertically(vertical_merge_test_pairs):
    """Check vertical group merging with two and with three input pairs.

    Two vertically adjacent pairs must collapse into the single expected
    pair. With an additional pair whose width and x2 are 30 larger than the
    second pair's, two results are expected, exactly one of which equals the
    expected merged pair.
    """
    first, second, expected = vertical_merge_test_pairs

    merged_group = merge_group_vertically([first, second])
    assert len(merged_group) == 1
    assert pair_equal(merged_group[0], expected)

    # Derive a third pair from the second one, but 30 units wider.
    wider_metadata = deepcopy(second.metadata)
    wider_metadata[Info.WIDTH] += 30
    wider_metadata[Info.X2] += 30
    wider_image = random_size_gray_image_from_metadata(wider_metadata)
    wider = ImageMetadataPair(wider_image, wider_metadata)

    merged_group = merge_group_vertically([first, second, wider])
    assert len(merged_group) == 2
    assert one(lambda candidate: pair_equal(expected, candidate), merged_group)
def pair_equal(pr1, pr2):
    """Return whether two image/metadata pairs match.

    The metadata is compared first; the images are only compared (via
    ``images_equal``) when the metadata comparison is truthy.
    """
    metadata_matches = pr1.metadata == pr2.metadata
    return metadata_matches and images_equal(pr1.image, pr2.image)
def test_merge_pairs_horizontally(horizontal_merge_test_pairs):
    """Merging two horizontally adjacent pairs must give the expected pair."""
    first, second, expected = horizontal_merge_test_pairs
    actual = merge_pair_horizontally(first, second)
    assert pair_equal(actual, expected)
def test_merge_pairs_vertically(vertical_merge_test_pairs):
    """Merging two vertically adjacent pairs must give the expected pair."""
    first, second, expected = vertical_merge_test_pairs
    actual = merge_pair_vertically(first, second)
    assert pair_equal(actual, expected)
def images_equal(im1: Image.Image, im2: Image.Image, **kwargs) -> bool:
    """Return whether two images are numerically close after normalization.

    Both images are converted with ``image_to_normalized_tensor`` and the
    results are compared element-wise with ``numpy.allclose``.

    Args:
        im1: First image to compare.
        im2: Second image to compare.
        **kwargs: Forwarded unchanged to ``numpy.allclose`` (for example
            ``atol`` or ``rtol``).

    Returns:
        The result of ``numpy.allclose`` on the two normalized tensors.

    Note:
        No explicit shape check is performed here; callers that care about
        image dimensions should compare ``.size`` separately.
    """
    # ``Image`` is the PIL module in this file, so the image class used in the
    # annotations is ``Image.Image``.
    tensor_1 = image_to_normalized_tensor(im1)
    tensor_2 = image_to_normalized_tensor(im2)
    return np.allclose(tensor_1, tensor_2, **kwargs)
@pytest.fixture
def horizontal_merge_test_pairs(horizontal_merge_test_metadata):
    """Pair every horizontal-merge metadata entry with a generated gray image.

    Returns a list of ``ImageMetadataPair`` objects in the same order as the
    metadata fixture (first, second, expected merge result).
    """
    return [
        ImageMetadataPair(random_size_gray_image_from_metadata(metadata), metadata)
        for metadata in horizontal_merge_test_metadata
    ]
@pytest.fixture
def vertical_merge_test_pairs(vertical_merge_test_metadata):
    """Pair every vertical-merge metadata entry with a generated gray image.

    Returns a list of ``ImageMetadataPair`` objects in the same order as the
    metadata fixture (first, second, expected merge result).
    """
    return [
        ImageMetadataPair(random_size_gray_image_from_metadata(metadata), metadata)
        for metadata in vertical_merge_test_metadata
    ]
def test_merge_metadata_horizontally(horizontal_merge_test_metadata):
    """Horizontally merging the two metadata entries must give the expected one."""
    first, second, expected = horizontal_merge_test_metadata
    actual = merge_metadata_horizontally(first, second)
    assert actual == expected
def test_merge_metadata_vertically(vertical_merge_test_metadata):
    """Vertically merging the two metadata entries must give the expected one."""
    first, second, expected = vertical_merge_test_metadata
    actual = merge_metadata_vertically(first, second)
    assert actual == expected
@pytest.fixture
def horizontal_merge_test_metadata(merge_test_metadata):
    """Build metadata for two horizontally adjacent boxes and their merge.

    The second box is shifted so that its x1 equals the first box's x2; the
    expected merged box gets the summed width and the second box's x2.

    Returns:
        Tuple ``(first, second, merged)`` of metadata mappings.
    """
    first, second, merged = merge_test_metadata

    # Place the second box directly after the first one along x.
    second[Info.X1] = first[Info.X2]
    second[Info.X2] = second[Info.X1] + second[Info.WIDTH]

    combined_width = first[Info.WIDTH] + second[Info.WIDTH]
    merged.update({Info.WIDTH: combined_width, Info.X2: second[Info.X2]})

    return first, second, merged
@pytest.fixture
def vertical_merge_test_metadata(merge_test_metadata):
    """Build metadata for two vertically adjacent boxes and their merge.

    The second box is shifted so that its y1 equals the first box's y2; the
    expected merged box gets the summed height and the second box's y2.

    Returns:
        Tuple ``(first, second, merged)`` of metadata mappings.
    """
    first, second, merged = merge_test_metadata

    # Place the second box directly after the first one along y.
    second[Info.Y1] = first[Info.Y2]
    second[Info.Y2] = second[Info.Y1] + second[Info.HEIGHT]

    combined_height = first[Info.HEIGHT] + second[Info.HEIGHT]
    merged.update({Info.HEIGHT: combined_height, Info.Y2: second[Info.Y2]})

    return first, second, merged
@pytest.fixture
def merge_test_metadata(base_patch_metadata):
    """Provide three independent deep copies of the base patch metadata.

    The copies are returned as a tuple so the fixture value can be unpacked
    or iterated more than once within a test. A lazily evaluated iterator
    would be exhausted after its first consumer.

    Returns:
        Tuple of three deep copies of ``base_patch_metadata``.
    """
    return tuple(deepcopy(base_patch_metadata) for _ in range(3))
@pytest.fixture
def base_patch_image(stitch_test_pdf):
    """Return the first image produced by converting the test PDF bytes."""
    rendered_pages = pdf2image.convert_from_bytes(stitch_test_pdf)
    return rendered_pages[0]
def test_concat_images_horizontally(horizontal_merge_test_metadata):
    """Horizontal image concatenation must match the expected size and content."""
    first_metadata, second_metadata, merged_metadata = horizontal_merge_test_metadata
    first_image, second_image, expected_image = [
        random_size_gray_image_from_metadata(metadata)
        for metadata in (first_metadata, second_metadata, merged_metadata)
    ]

    actual_image = concat_images_horizontally(first_image, second_image, merged_metadata)

    assert actual_image.size == expected_image.size
    assert images_equal(actual_image, expected_image)
def test_concat_images_vertically(vertical_merge_test_metadata):
    """Vertical image concatenation must match the expected size and content."""
    first_metadata, second_metadata, merged_metadata = vertical_merge_test_metadata
    first_image, second_image, expected_image = [
        random_size_gray_image_from_metadata(metadata)
        for metadata in (first_metadata, second_metadata, merged_metadata)
    ]

    actual_image = concat_images_vertically(first_image, second_image, merged_metadata)

    assert actual_image.size == expected_image.size
    assert images_equal(actual_image, expected_image)
@pytest.fixture
def stitch_test_pdf(patch_image_metadata_pairs, width, height):
    """Serialize a PDF that contains every patch pair.

    A document with point units and a ``(width, height)`` page format is
    created, each pair is handed to the ``add_image`` helper, and the
    document's string output is encoded as latin1 bytes.
    """
    document = fpdf.FPDF(unit="pt", format=(width, height))
    for patch_pair in patch_image_metadata_pairs:
        add_image(document, patch_pair)

    # dest="S" makes the library return the document instead of writing a file.
    serialized = document.output(dest="S")
    return serialized.encode("latin1")
@pytest.fixture
def patch_image_metadata_pairs(patches_metadata) -> List[ImageMetadataPair]:
    """Pair every patch metadata entry with a generated single-color image."""
    return [
        ImageMetadataPair(random_single_color_image_from_metadata(metadata), metadata)
        for metadata in patches_metadata
    ]
@pytest.fixture
def base_patch_metadata(width, height, page_width, page_height):
    """Return base position metadata with the box anchored at the origin.

    The metadata from ``get_base_position_metadata`` is merged with corner
    coordinates ``(0, 0)`` and ``(width, height)``.
    """
    position_metadata = get_base_position_metadata(width, height, page_width, page_height)
    corner_coordinates = {Info.X1: 0, Info.Y1: 0, Info.X2: width, Info.Y2: height}
    return merge(position_metadata, corner_coordinates)
@pytest.fixture
def patches_metadata(base_patch_metadata):
    """Return the list of metadata entries obtained by splitting the base box."""
    splitter = BoxSplitter()
    return list(splitter.split_box(base_patch_metadata))