image-classification-service/test/unit_tests/image_stitching_test.py
2022-04-12 15:04:32 +02:00

257 lines
9.1 KiB
Python

import json
import os
from copy import deepcopy
from copy import deepcopy
from functools import partial
from itertools import starmap, repeat
from operator import itemgetter
from typing import List
import fpdf
import numpy as np
import pdf2image
import pytest
from PIL import Image
from funcy import juxt, one, first
from image_prediction.estimator.preprocessor.utils import image_to_normalized_tensor
from image_prediction.formatter.formatters.enum import ReverseEnumFormatter
from image_prediction.image_extractor.extractor import ImageMetadataPair
from image_prediction.info import Info
from image_prediction.stitching.grouping import group_by_coordinate
from image_prediction.stitching.merging import (
merge_metadata_horizontally,
merge_metadata_vertically,
merge_pair_horizontally,
merge_pair_vertically,
concat_images_horizontally,
concat_images_vertically,
merge_group_horizontally,
merge_group_vertically,
)
from image_prediction.stitching.stitching import stitch_pairs
from image_prediction.stitching.utils import (
make_coord_getter,
make_length_getter,
)
from test.conftest import (
add_image,
random_single_color_image_from_metadata,
random_size_gray_image_from_metadata,
)
from test.utils.stitching import BoxSplitter
# Accessors produced by the stitching utils factories: one per box coordinate
# and one per box dimension.
# NOTE(review): none of these names are referenced in the visible part of this
# file — confirm they are still needed.
x1_getter = make_coord_getter("x1")
y1_getter = make_coord_getter("y1")
x2_getter = make_coord_getter("x2")
y2_getter = make_coord_getter("y2")
width_getter = make_length_getter("width")
height_getter = make_length_getter("height")
def test_group_by_coordinate_exact():
    """With zero tolerance, only pairs with an identical first element share a group."""
    coordinate_pairs = [(0, 1), (0, 3), (1, 4), (1, 4), (1, 2), (3, 3)]
    expected_groups = [[(0, 1), (0, 3)], [(1, 4), (1, 4), (1, 2)], [(3, 3)]]

    groups = list(group_by_coordinate(coordinate_pairs, itemgetter(0), tolerance=0))

    assert groups == expected_groups
def test_group_by_coordinate_fuzzy():
    """With tolerance 1, nearby first elements are grouped together.

    Note that 1 and 2 end up in different groups even though they differ by
    exactly the tolerance; presumably the distance is measured from the start
    of the current group — confirm against ``group_by_coordinate``.
    """
    coordinate_pairs = [(0, 1), (1, 3), (1, 4), (2, 4), (2, 2), (3, 3)]
    expected_groups = [[(0, 1), (1, 3), (1, 4)], [(2, 4), (2, 2), (3, 3)]]

    groups = list(group_by_coordinate(coordinate_pairs, itemgetter(0), tolerance=1))

    assert groups == expected_groups
def test_image_stitcher(patch_image_metadata_pairs, base_patch_metadata, base_patch_image):
    """Patches split from one base box stitch back into a single pair matching the base."""
    stitched = stitch_pairs(patch_image_metadata_pairs)
    assert len(stitched) == 1

    only_pair = first(stitched)
    assert only_pair.metadata == base_patch_metadata

    # The reference image is rasterized from a PDF (see ``base_patch_image``), so
    # compare heavily downscaled copies with a generous tolerance.
    thumbnail_size = (10, 10)
    assert images_equal(
        only_pair.image.resize(thumbnail_size),
        base_patch_image.resize(thumbnail_size),
        atol=0.4,
    )
def test_image_stitcher_with_gaps_must_succeed():
    """Patches with gaps between them stitch into one box when tolerance is large enough.

    Input patch metadata and the expected stitched metadata are loaded from the
    ``stitching_with_tolerance.json`` test data file (keys ``"input"`` and
    ``"target"``), then stitched with ``tolerance=7``.
    """
    from image_prediction.locations import TEST_DATA_DIR

    fixture_path = os.path.join(TEST_DATA_DIR, "stitching_with_tolerance.json")
    # JSON text is UTF-8; do not depend on the platform's default locale encoding.
    with open(fixture_path, encoding="utf-8") as f:
        raw_fixture = json.load(f)
    patches_metadata, base_patch_metadata = itemgetter("input", "target")(
        ReverseEnumFormatter(Info)(raw_fixture)
    )

    images = map(random_size_gray_image_from_metadata, patches_metadata)
    patch_image_metadata_pairs = list(starmap(ImageMetadataPair, zip(images, patches_metadata)))

    pairs_stitched = stitch_pairs(patch_image_metadata_pairs, tolerance=7)

    assert len(pairs_stitched) == 1
    assert first(pairs_stitched).metadata == base_patch_metadata
@pytest.mark.parametrize("noise", [(0, 2)])
@pytest.mark.parametrize("split_count", [5])
@pytest.mark.parametrize("width", [100])
@pytest.mark.parametrize("height", [100])
@pytest.mark.parametrize("page_width", [100])
@pytest.mark.parametrize("page_height", [100])
@pytest.mark.parametrize("execution_number", range(100))
@pytest.mark.xfail(reason="Does not always succeed due to locally maximizing merging logic.")
def test_image_stitcher_with_gaps_can_fail(patch_image_metadata_pairs, base_patch_metadata, execution_number):
    """Patches split with ``noise=(0, 2)`` should usually re-stitch into the base box.

    Marked xfail because stitching with ``tolerance=4`` does not always recover
    the base box. ``execution_number`` is unused in the body; the
    ``range(100)`` parametrization only repeats the test. The size parameters
    are presumably consumed by fixtures defined outside this file (e.g. the one
    providing ``base_patch_metadata``) — confirm in conftest.
    """
    stitched = stitch_pairs(patch_image_metadata_pairs, tolerance=4)
    assert len(stitched) == 1
    assert first(stitched).metadata == base_patch_metadata
def test_merge_group_horizontally(horizontal_merge_test_pairs):
    """Horizontal group merging joins side-by-side pairs but not a taller outlier."""
    left, right, expected = horizontal_merge_test_pairs

    merged = merge_group_horizontally([left, right])
    assert len(merged) == 1
    assert pair_equal(merged[0], expected)

    # A third pair placed like `right` but 30 units taller is expected to stay
    # separate, so merging all three yields two pairs, one of them `expected`.
    taller_metadata = deepcopy(right.metadata)
    taller_metadata[Info.HEIGHT] += 30
    taller_metadata[Info.Y2] += 30
    taller_image = random_size_gray_image_from_metadata(taller_metadata)
    taller = ImageMetadataPair(taller_image, taller_metadata)

    merged = merge_group_horizontally([left, right, taller])
    assert len(merged) == 2
    assert one(partial(pair_equal, expected), merged)
def test_merge_group_vertically(vertical_merge_test_pairs):
    """Vertical group merging joins stacked pairs but not a wider outlier."""
    top, bottom, expected = vertical_merge_test_pairs

    merged = merge_group_vertically([top, bottom])
    assert len(merged) == 1
    assert pair_equal(merged[0], expected)

    # A third pair placed like `bottom` but 30 units wider is expected to stay
    # separate, so merging all three yields two pairs, one of them `expected`.
    wider_metadata = deepcopy(bottom.metadata)
    wider_metadata[Info.WIDTH] += 30
    wider_metadata[Info.X2] += 30
    wider_image = random_size_gray_image_from_metadata(wider_metadata)
    wider = ImageMetadataPair(wider_image, wider_metadata)

    merged = merge_group_vertically([top, bottom, wider])
    assert len(merged) == 2
    assert one(partial(pair_equal, expected), merged)
def pair_equal(pr1, pr2):
    """Return True if both pairs have equal metadata and ``images_equal`` images.

    The metadata comparison runs first; images are only compared when it passes.
    """
    if not pr1.metadata == pr2.metadata:
        return False
    return images_equal(pr1.image, pr2.image)
def test_merge_pairs_horizontally(horizontal_merge_test_pairs):
    """Merging two side-by-side pairs produces the expected combined pair."""
    left, right, expected = horizontal_merge_test_pairs
    assert pair_equal(merge_pair_horizontally(left, right), expected)
def test_merge_pairs_vertically(vertical_merge_test_pairs):
    """Merging two stacked pairs produces the expected combined pair."""
    top, bottom, expected = vertical_merge_test_pairs
    assert pair_equal(merge_pair_vertically(top, bottom), expected)
def images_equal(im1: Image.Image, im2: Image.Image, **kwargs) -> bool:
    """Return whether two images are element-wise close after normalization.

    Both images are converted with ``image_to_normalized_tensor`` and compared
    with ``np.allclose``; any keyword arguments (e.g. ``atol``, ``rtol``) are
    forwarded to ``np.allclose``.
    """
    # `Image` is the PIL.Image module; the image type is `Image.Image`.
    return np.allclose(image_to_normalized_tensor(im1), image_to_normalized_tensor(im2), **kwargs)
@pytest.fixture
def horizontal_merge_test_pairs(horizontal_merge_test_metadata):
    """Wrap each horizontal-merge metadata entry in an ``ImageMetadataPair``.

    Each image comes from ``random_size_gray_image_from_metadata`` applied to
    that entry's metadata.
    """
    return [
        ImageMetadataPair(random_size_gray_image_from_metadata(metadata), metadata)
        for metadata in horizontal_merge_test_metadata
    ]
@pytest.fixture
def vertical_merge_test_pairs(vertical_merge_test_metadata):
    """Wrap each vertical-merge metadata entry in an ``ImageMetadataPair``.

    Each image comes from ``random_size_gray_image_from_metadata`` applied to
    that entry's metadata.
    """
    return [
        ImageMetadataPair(random_size_gray_image_from_metadata(metadata), metadata)
        for metadata in vertical_merge_test_metadata
    ]
def test_merge_metadata_horizontally(horizontal_merge_test_metadata):
    """Merging side-by-side metadata produces the precomputed merged metadata."""
    left, right, expected = horizontal_merge_test_metadata
    merged = merge_metadata_horizontally(left, right)
    assert merged == expected
def test_merge_metadata_vertically(vertical_merge_test_metadata):
    """Merging stacked metadata produces the precomputed merged metadata."""
    top, bottom, expected = vertical_merge_test_metadata
    merged = merge_metadata_vertically(top, bottom)
    assert merged == expected
@pytest.fixture
def horizontal_merge_test_metadata(merge_test_metadata):
    """Place two copies of the base box side by side and build their expected merge.

    Returns ``(left, right, merged)``, where:
    - ``right`` is moved so its X1 equals ``left``'s X2 (its width is kept);
    - ``merged`` keeps the base X1 and gets the summed width and ``right``'s X2.
    The three dicts come from ``merge_test_metadata`` and are mutated in place.
    """
    left, right, merged = merge_test_metadata

    right_x1 = left[Info.X2]
    right[Info.X1] = right_x1
    right[Info.X2] = right_x1 + right[Info.WIDTH]

    merged.update({Info.WIDTH: left[Info.WIDTH] + right[Info.WIDTH], Info.X2: right[Info.X2]})
    return left, right, merged
@pytest.fixture
def vertical_merge_test_metadata(merge_test_metadata):
    """Stack two copies of the base box vertically and build their expected merge.

    Returns ``(top, bottom, merged)``, where:
    - ``bottom`` is moved so its Y1 equals ``top``'s Y2 (its height is kept);
    - ``merged`` keeps the base Y1 and gets the summed height and ``bottom``'s Y2.
    The three dicts come from ``merge_test_metadata`` and are mutated in place.
    """
    top, bottom, merged = merge_test_metadata

    bottom_y1 = top[Info.Y2]
    bottom[Info.Y1] = bottom_y1
    bottom[Info.Y2] = bottom_y1 + bottom[Info.HEIGHT]

    merged.update({Info.HEIGHT: top[Info.HEIGHT] + bottom[Info.HEIGHT], Info.Y2: bottom[Info.Y2]})
    return top, bottom, merged
@pytest.fixture
def merge_test_metadata(base_patch_metadata):
    """Return three independent deep copies of the base patch metadata.

    Dependent fixtures unpack the result and mutate the copies in place, so
    each copy must be distinct from ``base_patch_metadata`` and from each other.
    """
    # Materialize eagerly: funcy's `juxt` yields lazily, which made the fixture
    # value a one-shot iterator that a second consumer would find exhausted.
    return tuple(deepcopy(base_patch_metadata) for _ in range(3))
@pytest.fixture
def base_patch_image(stitch_test_pdf):
    """Rasterize the stitching test PDF with pdf2image and return the first page image."""
    pages = pdf2image.convert_from_bytes(stitch_test_pdf)
    return pages[0]
def test_concat_images_horizontally(horizontal_merge_test_metadata):
    """Horizontally concatenating two images matches a reference image for the merged box."""
    left_meta, right_meta, merged_meta = horizontal_merge_test_metadata
    left_image = random_size_gray_image_from_metadata(left_meta)
    right_image = random_size_gray_image_from_metadata(right_meta)
    expected_image = random_size_gray_image_from_metadata(merged_meta)

    concatenated = concat_images_horizontally(left_image, right_image, merged_meta)

    assert concatenated.size == expected_image.size
    assert images_equal(concatenated, expected_image)
def test_concat_images_vertically(vertical_merge_test_metadata):
    """Vertically concatenating two images matches a reference image for the merged box."""
    top_meta, bottom_meta, merged_meta = vertical_merge_test_metadata
    top_image = random_size_gray_image_from_metadata(top_meta)
    bottom_image = random_size_gray_image_from_metadata(bottom_meta)
    expected_image = random_size_gray_image_from_metadata(merged_meta)

    concatenated = concat_images_vertically(top_image, bottom_image, merged_meta)

    assert concatenated.size == expected_image.size
    assert images_equal(concatenated, expected_image)
@pytest.fixture
def stitch_test_pdf(patch_image_metadata_pairs, width, height):
    """Build a PDF with page format ``(width, height)`` in points containing every patch.

    Each pair is placed via ``add_image``; the document is returned as bytes.
    """
    document = fpdf.FPDF(unit="pt", format=(width, height))
    for image_metadata_pair in patch_image_metadata_pairs:
        add_image(document, image_metadata_pair)
    # output(dest="S") is expected to return a latin-1 str here (PyFPDF 1.x
    # behaviour) — hence the explicit encode to bytes.
    rendered = document.output(dest="S")
    return rendered.encode("latin1")
@pytest.fixture
def patch_image_metadata_pairs(patches_metadata) -> List[ImageMetadataPair]:
    """Pair every patch's metadata with an image from ``random_single_color_image_from_metadata``."""
    return [
        ImageMetadataPair(random_single_color_image_from_metadata(metadata), metadata)
        for metadata in patches_metadata
    ]
@pytest.fixture
def patches_metadata(base_patch_metadata, noise, split_count):
    """Split the base box with ``BoxSplitter(noise)`` and return the patches as a list.

    ``split_count`` is passed straight through to ``BoxSplitter.split_box``.
    """
    splitter = BoxSplitter(noise)
    return list(splitter.split_box(base_patch_metadata, split_count))
@pytest.fixture(params=[(0, 0)])
def noise(request):
    """Noise passed to ``BoxSplitter``; defaults to ``(0, 0)``.

    Tests may override it via ``@pytest.mark.parametrize("noise", ...)``.
    """
    selected_noise = request.param
    return selected_noise
@pytest.fixture(params=[5])
def split_count(request):
    """Split count passed to ``BoxSplitter.split_box``; defaults to ``5``.

    Tests may override it via ``@pytest.mark.parametrize("split_count", ...)``.
    """
    selected_count = request.param
    return selected_count