import json
import os
from copy import deepcopy
from functools import partial
from itertools import starmap, repeat
from operator import itemgetter
from typing import List

import fpdf
import pdf2image
import pytest
from funcy import juxt, one, first

from image_prediction.formatter.formatters.enum import ReverseEnumFormatter
from image_prediction.image_extractor.extractor import ImageMetadataPair
from image_prediction.info import Info
from image_prediction.stitching.grouping import group_by_coordinate
from image_prediction.stitching.merging import (
    merge_metadata_horizontally,
    merge_metadata_vertically,
    merge_pair_horizontally,
    merge_pair_vertically,
    concat_images_horizontally,
    concat_images_vertically,
    merge_group_horizontally,
    merge_group_vertically,
)
from image_prediction.stitching.stitching import stitch_pairs
from image_prediction.stitching.utils import (
    make_coord_getter,
    make_length_getter,
)
from test.utils.comparison import images_equal
from test.utils.generation.image import random_single_color_image_from_metadata, gray_image_from_metadata
from test.utils.generation.pdf import add_image
from test.utils.stitching import BoxSplitter

# Module-level accessors built from the stitching utils; not referenced by the tests below,
# kept in case other modules import them from here.
x1_getter, y1_getter, x2_getter, y2_getter = map(make_coord_getter, ("x1", "y1", "x2", "y2"))
width_getter, height_getter = map(make_length_getter, ("width", "height"))


def test_group_by_coordinate_exact():
    """With tolerance 0, items are grouped only when their key (first element) is equal."""
    pairs = [(0, 1), (0, 3), (1, 4), (1, 4), (1, 2), (3, 3)]
    pairs_grouped = list(group_by_coordinate(pairs, itemgetter(0), tolerance=0))
    assert pairs_grouped == [[(0, 1), (0, 3)], [(1, 4), (1, 4), (1, 2)], [(3, 3)]]


def test_group_by_coordinate_fuzzy():
    """With tolerance 1, nearby keys fall into the same group as pinned below."""
    pairs = [(0, 1), (1, 3), (1, 4), (2, 4), (2, 2), (3, 3)]
    pairs_grouped = list(group_by_coordinate(pairs, itemgetter(0), tolerance=1))
    assert pairs_grouped == [[(0, 1), (1, 3), (1, 4)], [(2, 4), (2, 2), (3, 3)]]


def test_image_stitcher(patch_image_metadata_pairs, base_patch_metadata, base_patch_image):
    """Stitching the patches of a split box must yield one pair matching the original box."""
    pairs_stitched = stitch_pairs(patch_image_metadata_pairs)

    # Check the count before taking the first element, as the other stitching tests do.
    assert len(pairs_stitched) == 1
    pair_stitched = first(pairs_stitched)

    assert pair_stitched.metadata == base_patch_metadata
    # Compared at a coarse 10x10 resolution with a loose tolerance; presumably because the
    # PDF rasterisation behind base_patch_image is not pixel-exact. TODO confirm.
    assert images_equal(pair_stitched.image.resize((10, 10)), base_patch_image.resize((10, 10)), atol=0.4)


def test_image_stitcher_with_gaps_must_succeed():
    """A recorded set of gapped patches must stitch into the recorded target with tolerance 7."""
    from image_prediction.locations import TEST_DATA_DIR

    # Explicit encoding so JSON decoding does not depend on the platform's locale default.
    with open(os.path.join(TEST_DATA_DIR, "stitching_with_tolerance.json"), encoding="utf-8") as f:
        patches_metadata, base_patch_metadata = itemgetter("input", "target")(ReverseEnumFormatter(Info)(json.load(f)))

    images = map(gray_image_from_metadata, patches_metadata)
    patch_image_metadata_pairs = list(starmap(ImageMetadataPair, zip(images, patches_metadata)))

    pairs_stitched = stitch_pairs(patch_image_metadata_pairs, tolerance=7)

    assert len(pairs_stitched) == 1
    pair_stitched = first(pairs_stitched)
    assert pair_stitched.metadata == base_patch_metadata


@pytest.mark.parametrize("noise", [(0, 2)])
@pytest.mark.parametrize("split_count", [5])
@pytest.mark.parametrize("width", [100])
@pytest.mark.parametrize("height", [100])
@pytest.mark.parametrize("page_width", [100])
@pytest.mark.parametrize("page_height", [100])
@pytest.mark.parametrize("execution_number", range(100))
@pytest.mark.xfail(reason="Does not always succeed due to locally maximizing merging logic.")
def test_image_stitcher_with_gaps_can_fail(patch_image_metadata_pairs, base_patch_metadata, execution_number):
    """Randomly gapped patches usually, but not always, stitch back into the original box.

    ``execution_number`` is unused; it only repeats the test 100 times. The parametrized
    ``noise``/``split_count``/``width``/``height``/``page_*`` values override the fixtures
    of the same names.
    """
    pairs_stitched = stitch_pairs(patch_image_metadata_pairs, tolerance=4)
    assert len(pairs_stitched) == 1 and first(pairs_stitched).metadata == base_patch_metadata


def test_merge_group_horizontally(horizontal_merge_test_pairs):
    """Two side-by-side pairs merge into one; an extra, taller pair stays separate."""
    pr1, pr2, pr_merged_expected = horizontal_merge_test_pairs

    prs_merged = merge_group_horizontally([pr1, pr2])
    assert len(prs_merged) == 1
    assert pair_equal(prs_merged[0], pr_merged_expected)

    # A third patch shaped like pr2 but 30 units taller.
    mdat3 = deepcopy(pr2.metadata)
    mdat3[Info.HEIGHT] += 30
    mdat3[Info.Y2] += 30
    im3 = gray_image_from_metadata(mdat3)
    pr3 = ImageMetadataPair(im3, mdat3)

    prs_merged = merge_group_horizontally([pr1, pr2, pr3])
    assert len(prs_merged) == 2
    assert one(partial(pair_equal, pr_merged_expected), prs_merged)


def test_merge_group_vertically(vertical_merge_test_pairs):
    """Two stacked pairs merge into one; an extra, wider pair stays separate."""
    pr1, pr2, pr_merged_expected = vertical_merge_test_pairs

    prs_merged = merge_group_vertically([pr1, pr2])
    assert len(prs_merged) == 1
    assert pair_equal(prs_merged[0], pr_merged_expected)

    # A third patch shaped like pr2 but 30 units wider.
    mdat3 = deepcopy(pr2.metadata)
    mdat3[Info.WIDTH] += 30
    mdat3[Info.X2] += 30
    im3 = gray_image_from_metadata(mdat3)
    pr3 = ImageMetadataPair(im3, mdat3)

    prs_merged = merge_group_vertically([pr1, pr2, pr3])
    assert len(prs_merged) == 2
    assert one(partial(pair_equal, pr_merged_expected), prs_merged)


def pair_equal(pr1, pr2):
    """Return True when both metadata dicts are equal and ``images_equal`` accepts the images."""
    return pr1.metadata == pr2.metadata and images_equal(pr1.image, pr2.image)


def test_merge_pairs_horizontally(horizontal_merge_test_pairs):
    """Merging two side-by-side pairs yields the expected merged pair."""
    pr1, pr2, pr_merged_expected = horizontal_merge_test_pairs
    pr_merged = merge_pair_horizontally(pr1, pr2)
    assert pair_equal(pr_merged, pr_merged_expected)


def test_merge_pairs_vertically(vertical_merge_test_pairs):
    """Merging two stacked pairs yields the expected merged pair."""
    pr1, pr2, pr_merged_expected = vertical_merge_test_pairs
    pr_merged = merge_pair_vertically(pr1, pr2)
    assert pair_equal(pr_merged, pr_merged_expected)


@pytest.fixture
def horizontal_merge_test_pairs(horizontal_merge_test_metadata):
    """(left, right, expected merge) as ImageMetadataPairs with gray images."""
    images = map(gray_image_from_metadata, horizontal_merge_test_metadata)
    return list(starmap(ImageMetadataPair, zip(images, horizontal_merge_test_metadata)))


@pytest.fixture
def vertical_merge_test_pairs(vertical_merge_test_metadata):
    """(top, bottom, expected merge) as ImageMetadataPairs with gray images."""
    images = map(gray_image_from_metadata, vertical_merge_test_metadata)
    return list(starmap(ImageMetadataPair, zip(images, vertical_merge_test_metadata)))


def test_merge_metadata_horizontally(horizontal_merge_test_metadata):
    """Horizontal metadata merge matches the expected merged metadata."""
    mdat1, mdat2, mdat_merged = horizontal_merge_test_metadata
    assert merge_metadata_horizontally(mdat1, mdat2) == mdat_merged


def test_merge_metadata_vertically(vertical_merge_test_metadata):
    """Vertical metadata merge matches the expected merged metadata."""
    mdat1, mdat2, mdat_merged = vertical_merge_test_metadata
    assert merge_metadata_vertically(mdat1, mdat2) == mdat_merged


@pytest.fixture
def horizontal_merge_test_metadata(merge_test_metadata):
    """Place mdat2 directly right of mdat1 and set mdat_merged to span both horizontally."""
    mdat1, mdat2, mdat_merged = merge_test_metadata
    mdat2[Info.X1] = mdat1[Info.X2]
    mdat2[Info.X2] = mdat2[Info.X1] + mdat2[Info.WIDTH]
    mdat_merged.update({Info.WIDTH: mdat1[Info.WIDTH] + mdat2[Info.WIDTH], Info.X2: mdat2[Info.X2]})
    return mdat1, mdat2, mdat_merged


@pytest.fixture
def vertical_merge_test_metadata(merge_test_metadata):
    """Place mdat2 directly below mdat1 and set mdat_merged to span both vertically."""
    mdat1, mdat2, mdat_merged = merge_test_metadata
    mdat2[Info.Y1] = mdat1[Info.Y2]
    mdat2[Info.Y2] = mdat2[Info.Y1] + mdat2[Info.HEIGHT]
    mdat_merged.update({Info.HEIGHT: mdat1[Info.HEIGHT] + mdat2[Info.HEIGHT], Info.Y2: mdat2[Info.Y2]})
    return mdat1, mdat2, mdat_merged


@pytest.fixture
def merge_test_metadata(base_patch_metadata):
    """Three independent deep copies of ``base_patch_metadata``.

    funcy's ``juxt`` returns a lazy generator; it is materialized into a tuple so that the
    fixture value can be unpacked by more than one consumer within the same test.
    """
    return tuple(juxt(*repeat(deepcopy, 3))(base_patch_metadata))


@pytest.fixture
def base_patch_image(stitch_test_pdf):
    """First page of ``stitch_test_pdf`` rasterised with pdf2image."""
    return pdf2image.convert_from_bytes(stitch_test_pdf)[0]


def test_concat_images_horizontally(horizontal_merge_test_metadata):
    """Concatenating two side-by-side gray images matches a gray image of the merged box."""
    mdat1, mdat2, mdat_merged = horizontal_merge_test_metadata
    im1, im2, im_merged_expected = map(gray_image_from_metadata, [mdat1, mdat2, mdat_merged])
    im_merged = concat_images_horizontally(im1, im2, mdat_merged)
    assert im_merged.size == im_merged_expected.size
    assert images_equal(im_merged, im_merged_expected)


def test_concat_images_vertically(vertical_merge_test_metadata):
    """Concatenating two stacked gray images matches a gray image of the merged box."""
    mdat1, mdat2, mdat_merged = vertical_merge_test_metadata
    im1, im2, im_merged_expected = map(gray_image_from_metadata, [mdat1, mdat2, mdat_merged])
    im_merged = concat_images_vertically(im1, im2, mdat_merged)
    assert im_merged.size == im_merged_expected.size
    assert images_equal(im_merged, im_merged_expected)


@pytest.fixture
def stitch_test_pdf(patch_image_metadata_pairs, width, height):
    """Draw every patch into a PDF with page format (width, height) in points; return its bytes.

    NOTE(review): ``output(dest="S").encode("latin1")`` assumes the fpdf 1.x API, where
    ``output`` returns a latin-1 ``str``; confirm against the pinned fpdf version.
    """
    pdf = fpdf.FPDF(unit="pt", format=(width, height))
    for pair in patch_image_metadata_pairs:
        add_image(pdf, pair)
    return pdf.output(dest="S").encode("latin1")


@pytest.fixture
def patch_image_metadata_pairs(patches_metadata) -> List[ImageMetadataPair]:
    """Pair each patch's metadata with a random single-color image generated from it."""
    images = map(random_single_color_image_from_metadata, patches_metadata)
    return list(starmap(ImageMetadataPair, zip(images, patches_metadata)))


@pytest.fixture
def patches_metadata(base_patch_metadata, noise, split_count):
    """Split ``base_patch_metadata`` into ``split_count`` patches using ``BoxSplitter(noise)``."""
    return list(BoxSplitter(noise).split_box(base_patch_metadata, split_count))


@pytest.fixture(params=[(0, 0)])
def noise(request):
    """BoxSplitter noise; default (0, 0), overridden via parametrize in some tests."""
    return request.param


@pytest.fixture(params=[5])
def split_count(request):
    """Number of patches to split the base box into; default 5."""
    return request.param