"""Tests for stitching image patches (image/metadata pairs) back into a single image."""

from copy import deepcopy
from functools import partial, reduce
from itertools import groupby
from itertools import starmap, chain, repeat
from typing import Iterable, List

import fpdf
import numpy as np
import pdf2image
import pytest
from PIL import Image
from funcy import merge, second, compose, rpartial, juxt, rest, first, one, iterate

from image_prediction.estimator.preprocessor.utils import image_to_normalized_tensor
from image_prediction.image_extractor.extractor import ImageMetadataPair
from image_prediction.info import Info
from image_prediction.utils import chunk_iterable
from test.conftest import (
    get_base_position_metadata,
    add_image,
    random_single_color_image_from_metadata,
    random_size_gray_image_from_metadata,
)
from test.utils.stitching import BoxSplitter, VerticalKeyMapper, HorizontalKeyMapper


def make_getter(key):
    """Return a function that reads ``key`` from an ImageMetadataPair's metadata."""

    def getter(pair):
        return pair.metadata[key]

    return getter


def make_coord_getter(c):
    """Return a metadata getter for the coordinate name ``c`` ("x1", "x2", "y1" or "y2")."""
    return {
        "x1": make_getter(Info.X1),
        "x2": make_getter(Info.X2),
        "y1": make_getter(Info.Y1),
        "y2": make_getter(Info.Y2),
    }[c]


def make_pair_merger(direction):
    """Return the pair-merging function for ``direction`` ("y" -> vertical, "x" -> horizontal)."""
    return {"y": merge_pair_vertically, "x": merge_pair_horizontally}[direction]


def make_length_getter(dim):
    """Return a metadata getter for the extent ``dim`` ("width" or "height")."""
    return {
        "width": make_getter(Info.WIDTH),
        "height": make_getter(Info.HEIGHT),
    }[dim]


x1_getter, y1_getter, x2_getter, y2_getter = map(make_coord_getter, ("x1", "y1", "x2", "y2"))
width_getter, height_getter = map(make_length_getter, ("width", "height"))


def merge_metadata_horizontally(m1, m2):
    """Merge two metadata records that are adjacent along the x axis."""
    m1, m2 = map(HorizontalKeyMapper, [m1, m2])
    return merge_metadata(m1, m2)


def merge_metadata_vertically(m1, m2):
    """Merge two metadata records that are adjacent along the y axis."""
    m1, m2 = map(VerticalKeyMapper, [m1, m2])
    return merge_metadata(m1, m2)


def merge_metadata(m1, m2):
    """Merge two key-mapped metadata records along their mapped direction.

    The merged record spans from the smaller ``c1`` to the larger ``c2``, and its ``dim`` is the sum
    of both ``dim`` values. All other entries are copied from ``m1``. The key mappers presumably map
    c1/c2/dim onto the x1/x2/width (or y1/y2/height) keys and expose the underlying record via
    ``.wrapped`` -- NOTE(review): confirm against test.utils.stitching.
    """
    c1 = min(m1.c1, m2.c1)
    c2 = max(m1.c2, m2.c2)
    dim = m1.dim + m2.dim
    merged = deepcopy(m1)
    merged.dim = dim
    merged.c1 = c1
    merged.c2 = c2
    return merged.wrapped


def merge_pair_horizontally(p1: ImageMetadataPair, p2: ImageMetadataPair):
    """Merge two horizontally adjacent pairs; ``p2`` is placed to the right of ``p1``."""
    metadata_merged = merge_metadata_horizontally(p1.metadata, p2.metadata)
    image_concatenated = concat_images_horizontally(p1.image, p2.image, metadata_merged)
    return ImageMetadataPair(image_concatenated, metadata_merged)


def merge_pair_vertically(p1: ImageMetadataPair, p2: ImageMetadataPair):
    """Merge two vertically adjacent pairs; ``p2`` is placed below ``p1``."""
    metadata_merged = merge_metadata_vertically(p1.metadata, p2.metadata)
    image_concatenated = concat_images_vertically(p1.image, p2.image, metadata_merged)
    return ImageMetadataPair(image_concatenated, metadata_merged)


def concat_images_horizontally(im1: Image.Image, im2: Image.Image, metadata: dict):
    """Concatenate ``im2`` to the right of ``im1`` on a canvas sized by ``metadata``."""
    return concat_images(im1, im2, metadata, 0)


def concat_images_vertically(im1: Image.Image, im2: Image.Image, metadata: dict):
    """Concatenate ``im2`` below ``im1`` on a canvas sized by ``metadata``."""
    return concat_images(im1, im2, metadata, 1)


def concat_images(im1: Image.Image, im2: Image.Image, metadata: dict, axis):
    """Paste ``im1`` at the origin and ``im2`` right after it along ``axis`` (0 = x, 1 = y).

    The canvas has ``im1``'s mode and the size (metadata[Info.WIDTH], metadata[Info.HEIGHT]).
    """
    im_aggr = Image.new(im1.mode, (metadata[Info.WIDTH], metadata[Info.HEIGHT]))
    # The second image starts where the first one ends along the concatenation axis.
    offset = im1.size[axis]
    box2 = (offset, 0) if axis == 0 else (0, offset)
    im_aggr.paste(im1, box=(0, 0))
    im_aggr.paste(im2, box=box2)
    return im_aggr


def make_merger_aggregator(direction):
    """Produces a function f : [H, T1, ... Tn] -> [HTi...Tj, Tk ... Tl] that merges adjacent
    image-metadata pairs on the head H and aggregates non-adjacent pairs in the tail T.

    Only the head absorbs pairs; tail pairs are never merged with each other.
    """
    c1_getter = make_coord_getter(f"{direction}1")
    c2_getter = make_coord_getter(f"{direction}2")
    pair_merger = make_pair_merger(direction)

    def merger_aggregator(pairs):
        def merge_on_head_and_aggregate_in_tail(pairs_aggr: Iterable[ImageMetadataPair], pair: ImageMetadataPair):
            """Keep the pair being merged into as the head and collect non-mergeable pairs in the tail."""
            aggr, non_aggr = juxt(first, rest)(pairs_aggr)
            if c2_getter(aggr) == c1_getter(pair):
                # ``pair`` starts exactly where the head ends: concatenate it onto the head.
                aggr = pair_merger(aggr, pair)
                return aggr, *non_aggr
            else:
                return aggr, pair, *non_aggr

        # Requires H to be the least element in the image-concatenation direction by c1, since the
        # concatenation happens only in the c1 -> c2 direction.
        pairs = sorted(pairs, key=c1_getter)
        if not pairs:
            # Without this guard `first` yields None, which a later pass would treat as a pair.
            return []
        aggregation_pair, pairs = juxt(first, rest)(pairs)
        return list(reduce(merge_on_head_and_aggregate_in_tail, pairs, [aggregation_pair]))

    return merger_aggregator


def merge_group(group, direction):
    """Repeatedly merge-aggregate ``group`` along ``direction`` until the pair count stops shrinking.

    ``iterate`` yields group, f(group), f(f(group)), ... . Consecutive results are inspected two at a
    time (assuming ``chunk_iterable`` yields non-overlapping chunks). The first chunk whose members
    have equal length is taken as converged, since a pass that merges nothing leaves the length
    unchanged.
    """
    reduce_group = make_merger_aggregator(direction)
    for g1, g2 in chunk_iterable(iterate(reduce_group, group), chunk_size=2):
        if len(g1) == len(g2):
            return g1


def merge_group_horizontally(group):
    """Merge horizontally adjacent pairs of ``group``."""
    return merge_group(group, "x")


def merge_group_vertically(group):
    """Merge vertically adjacent pairs of ``group``."""
    return merge_group(group, "y")


class Stitcher:
    """Stitches adjacent image/metadata pairs back together until no further merges occur."""

    @staticmethod
    def groupby(pairs, coord):
        """Group ``pairs`` into lists that share the same value of coordinate ``coord``."""
        coord_getter = make_coord_getter(coord)
        # itertools.groupby only groups consecutive items, so sort by the same key first.
        pairs = sorted(pairs, key=coord_getter)
        return map(compose(list, second), groupby(pairs, coord_getter))

    def stitch(self, pairs: Iterable[ImageMetadataPair]) -> List[ImageMetadataPair]:
        """Alternate vertical and horizontal merging until the number of pairs no longer changes."""
        pairs = list(pairs)
        n = len(pairs)
        while True:
            # Pairs sharing both x1 and x2 form one column and may be merged vertically.
            groups = self.groupby(pairs, "x1")
            groups = chain.from_iterable(map(rpartial(self.groupby, "x2"), groups))
            groups = map(merge_group_vertically, groups)
            pairs = chain.from_iterable(groups)

            # Pairs sharing both y1 and y2 form one row and may be merged horizontally.
            groups = self.groupby(pairs, "y1")
            groups = chain.from_iterable(map(rpartial(self.groupby, "y2"), groups))
            groups = map(merge_group_horizontally, groups)
            pairs = list(chain.from_iterable(groups))

            if len(pairs) == n:
                return pairs
            n = len(pairs)


# ---------------------------------------------------------------------------
# Tests and fixtures
#
# NOTE(review): the `width`, `height`, `page_width` and `page_height` fixtures are not defined in
# this module; presumably they come from conftest -- confirm.
# ---------------------------------------------------------------------------


def test_image_stitcher(patch_image_metadata_pairs, base_patch_metadata, base_patch_image):
    pair_stitched = Stitcher().stitch(patch_image_metadata_pairs)[0]
    assert pair_stitched.metadata == base_patch_metadata
    # Compare downscaled images with a loose tolerance rather than pixel-exact content.
    assert images_equal(pair_stitched.image.resize((10, 10)), base_patch_image.resize((10, 10)), atol=0.4)


def test_merge_group_horizontally(horizontal_merge_test_pairs):
    pr1, pr2, pr_merged_expected = horizontal_merge_test_pairs
    prs_merged = merge_group_horizontally([pr1, pr2])
    assert len(prs_merged) == 1
    assert pair_equal(prs_merged[0], pr_merged_expected)

    # A third pair with a different height must stay separate from the merged pair.
    mdat3 = deepcopy(pr2.metadata)
    mdat3[Info.HEIGHT] += 30
    mdat3[Info.Y2] += 30
    im3 = random_size_gray_image_from_metadata(mdat3)
    pr3 = ImageMetadataPair(im3, mdat3)
    prs_merged = merge_group_horizontally([pr1, pr2, pr3])
    assert len(prs_merged) == 2
    assert one(partial(pair_equal, pr_merged_expected), prs_merged)


def test_merge_group_vertically(vertical_merge_test_pairs):
    pr1, pr2, pr_merged_expected = vertical_merge_test_pairs
    prs_merged = merge_group_vertically([pr1, pr2])
    assert len(prs_merged) == 1
    assert pair_equal(prs_merged[0], pr_merged_expected)

    # A third pair with a different width must stay separate from the merged pair.
    mdat3 = deepcopy(pr2.metadata)
    mdat3[Info.WIDTH] += 30
    mdat3[Info.X2] += 30
    im3 = random_size_gray_image_from_metadata(mdat3)
    pr3 = ImageMetadataPair(im3, mdat3)
    prs_merged = merge_group_vertically([pr1, pr2, pr3])
    assert len(prs_merged) == 2
    assert one(partial(pair_equal, pr_merged_expected), prs_merged)


def pair_equal(pr1, pr2):
    """Return True if both pairs have equal metadata and numerically close images."""
    return pr1.metadata == pr2.metadata and images_equal(pr1.image, pr2.image)


def test_merge_pairs_horizontally(horizontal_merge_test_pairs):
    pr1, pr2, pr_merged_expected = horizontal_merge_test_pairs
    pr_merged = merge_pair_horizontally(pr1, pr2)
    assert pair_equal(pr_merged, pr_merged_expected)


def test_merge_pairs_vertically(vertical_merge_test_pairs):
    pr1, pr2, pr_merged_expected = vertical_merge_test_pairs
    pr_merged = merge_pair_vertically(pr1, pr2)
    assert pair_equal(pr_merged, pr_merged_expected)


def images_equal(im1: Image.Image, im2: Image.Image, **kwargs):
    """Compare two images via np.allclose on their normalized tensors; kwargs go to np.allclose."""
    return np.allclose(image_to_normalized_tensor(im1), image_to_normalized_tensor(im2), **kwargs)


@pytest.fixture
def horizontal_merge_test_pairs(horizontal_merge_test_metadata):
    images = map(random_size_gray_image_from_metadata, horizontal_merge_test_metadata)
    return list(starmap(ImageMetadataPair, zip(images, horizontal_merge_test_metadata)))


@pytest.fixture
def vertical_merge_test_pairs(vertical_merge_test_metadata):
    images = map(random_size_gray_image_from_metadata, vertical_merge_test_metadata)
    return list(starmap(ImageMetadataPair, zip(images, vertical_merge_test_metadata)))


def test_merge_metadata_horizontally(horizontal_merge_test_metadata):
    mdat1, mdat2, mdat_merged = horizontal_merge_test_metadata
    assert merge_metadata_horizontally(mdat1, mdat2) == mdat_merged


def test_merge_metadata_vertically(vertical_merge_test_metadata):
    mdat1, mdat2, mdat_merged = vertical_merge_test_metadata
    assert merge_metadata_vertically(mdat1, mdat2) == mdat_merged


@pytest.fixture
def horizontal_merge_test_metadata(merge_test_metadata):
    """(left, right, expected merged) metadata, with ``right`` starting where ``left`` ends in x."""
    mdat1, mdat2, mdat_merged = merge_test_metadata
    mdat2[Info.X1] = mdat1[Info.X2]
    mdat2[Info.X2] = mdat2[Info.X1] + mdat2[Info.WIDTH]
    mdat_merged.update({Info.WIDTH: mdat1[Info.WIDTH] + mdat2[Info.WIDTH], Info.X2: mdat2[Info.X2]})
    return mdat1, mdat2, mdat_merged


@pytest.fixture
def vertical_merge_test_metadata(merge_test_metadata):
    """(top, bottom, expected merged) metadata, with ``bottom`` starting where ``top`` ends in y."""
    mdat1, mdat2, mdat_merged = merge_test_metadata
    mdat2[Info.Y1] = mdat1[Info.Y2]
    mdat2[Info.Y2] = mdat2[Info.Y1] + mdat2[Info.HEIGHT]
    mdat_merged.update({Info.HEIGHT: mdat1[Info.HEIGHT] + mdat2[Info.HEIGHT], Info.Y2: mdat2[Info.Y2]})
    return mdat1, mdat2, mdat_merged


@pytest.fixture
def merge_test_metadata(base_patch_metadata):
    """Three independent deep copies of the base patch metadata, safe to mutate separately."""
    return juxt(*repeat(deepcopy, 3))(base_patch_metadata)


@pytest.fixture
def base_patch_image(stitch_test_pdf):
    """First page of the rendered test PDF."""
    return pdf2image.convert_from_bytes(stitch_test_pdf)[0]


def test_concat_images_horizontally(horizontal_merge_test_metadata):
    mdat1, mdat2, mdat_merged = horizontal_merge_test_metadata
    im1, im2, im_merged_expected = map(random_size_gray_image_from_metadata, [mdat1, mdat2, mdat_merged])
    im_merged = concat_images_horizontally(im1, im2, mdat_merged)
    assert im_merged.size == im_merged_expected.size
    assert images_equal(im_merged, im_merged_expected)


def test_concat_images_vertically(vertical_merge_test_metadata):
    mdat1, mdat2, mdat_merged = vertical_merge_test_metadata
    im1, im2, im_merged_expected = map(random_size_gray_image_from_metadata, [mdat1, mdat2, mdat_merged])
    im_merged = concat_images_vertically(im1, im2, mdat_merged)
    assert im_merged.size == im_merged_expected.size
    assert images_equal(im_merged, im_merged_expected)


@pytest.fixture
def stitch_test_pdf(patch_image_metadata_pairs, width, height):
    """A single-page PDF (page size width x height pt) containing every patch, as bytes."""
    pdf = fpdf.FPDF(unit="pt", format=(width, height))
    for pair in patch_image_metadata_pairs:
        add_image(pdf, pair)
    return pdf.output(dest="S").encode("latin1")


@pytest.fixture
def patch_image_metadata_pairs(patches_metadata) -> List[ImageMetadataPair]:
    images = map(random_single_color_image_from_metadata, patches_metadata)
    return list(starmap(ImageMetadataPair, zip(images, patches_metadata)))


@pytest.fixture
def base_patch_metadata(width, height, page_width, page_height):
    """Metadata of the full (unsplit) patch, anchored at the origin with size width x height."""
    metadata = get_base_position_metadata(width, height, page_width, page_height)
    metadata = merge(metadata, {Info.X1: 0, Info.Y1: 0, Info.X2: width, Info.Y2: height})
    return metadata


@pytest.fixture
def patches_metadata(base_patch_metadata):
    """Metadata of the patches produced by splitting the base patch with BoxSplitter."""
    patches_metadata = list(BoxSplitter().split_box(base_patch_metadata))
    return patches_metadata