diff --git a/test/unit_tests/image_stitcher_test.py b/test/unit_tests/image_stitcher_test.py index d53fc38..c8c7e45 100644 --- a/test/unit_tests/image_stitcher_test.py +++ b/test/unit_tests/image_stitcher_test.py @@ -1,12 +1,13 @@ from copy import deepcopy from functools import partial from itertools import starmap, chain, repeat -from operator import itemgetter +from operator import itemgetter, attrgetter from typing import Iterable, List import fpdf import pytest -from funcy import merge, second, compose, rpartial, curry, juxt, project, omit +from PIL import Image +from funcy import merge, second, compose, rpartial, curry, juxt, project, omit, flip, identity from image_prediction.image_extractor.extractor import ImageMetadataPair from image_prediction.info import Info @@ -42,13 +43,13 @@ x1_getter, y1_getter, x2_getter, y2_getter = map(make_coord_getter, ("x1", "y1", width_getter, height_getter = map(make_length_getter, ("width", "height")) -# def merge_group(group, axis="y"): -# -# group = list(group) -# current_pair = group.pop(0) -# for pair in group: -# if y2_getter(current_pair) == y1_getter(pair): -# current_box = merge_pair(current_pair, pair) +def merge_group(group, axis="y"): + + group = list(group) + current_pair = group.pop(0) + for pair in group: + if y2_getter(current_pair) == y1_getter(pair): + current_box = merge_pair_vertically(current_pair, pair) def merge_metadata_horizontally(m1, m2): @@ -79,12 +80,57 @@ def merge_pair_horizontally(p1: ImageMetadataPair, p2: ImageMetadataPair): mdat_merged = merge_metadata_horizontally(p1.metadata, p2.metadata) -# def merge_pair(p1, p2): -# -# assert p1.metadata[Info.PAGE_IDX] == p2.metadta[Info.PAGE_IDX] +def merge_pair_vertically(p1: ImageMetadataPair, p2: ImageMetadataPair): + mdat_merged = merge_metadata_vertically(p1.metadata, p2.metadata) -def test_merge_metadata_horizontally(merge_test_metadata): +def merge_pair_horizontally(p1: ImageMetadataPair, p2: ImageMetadataPair): + pass + + +def merge_pair(p1, p2): + assert p1.metadata[Info.PAGE_IDX] == p2.metadta[Info.PAGE_IDX] + + +def concat_images_horizontally(im1: Image, im2: Image, metadata: dict): + return concat_images(im1, im2, metadata, 0) + + +def concat_images_vertically(im1: Image, im2: Image, metadata: dict): + return concat_images(im1, im2, metadata, 1) + + +def concat_images(im1: Image, im2: Image, metadata: dict, axis): + + im_aggr = Image.new(im1.mode, (metadata[Info.WIDTH], metadata[Info.HEIGHT])) + + images = [im1, im2] + + offsets = [0, *[im.size[axis] for im in images]] + + for im, offset in zip(images, offsets): + box = (offset, 0) if axis else (0, offset) + im_aggr.paste(im, box=box) + + return im_aggr + + +##################################### + + +def test_merge_metadata_horizontally(horizontal_merge_test_metadata): + mdat1, mdat2, mdat_merged = horizontal_merge_test_metadata + assert merge_metadata_horizontally(mdat1, mdat2) == mdat_merged + + +def test_merge_metadata_vertically(vertical_merge_test_metadata): + mdat1, mdat2, mdat_merged = vertical_merge_test_metadata + assert merge_metadata_vertically(mdat1, mdat2) == mdat_merged + + +@pytest.fixture +def horizontal_merge_test_metadata(merge_test_metadata): + mdat1, mdat2, mdat_merged = merge_test_metadata mdat2[Info.X1] = mdat1[Info.X2] @@ -92,10 +138,12 @@ def test_merge_metadata_horizontally(merge_test_metadata): mdat_merged.update({Info.WIDTH: mdat1[Info.WIDTH] + mdat2[Info.WIDTH], Info.X2: mdat2[Info.X2]}) - assert merge_metadata_horizontally(mdat1, mdat2) == mdat_merged + return mdat1, mdat2, mdat_merged -def test_merge_metadata_vertically(merge_test_metadata): +@pytest.fixture +def vertical_merge_test_metadata(merge_test_metadata): + mdat1, mdat2, mdat_merged = merge_test_metadata mdat2[Info.Y1] = mdat1[Info.Y2] @@ -103,27 +151,43 @@ def test_merge_metadata_vertically(merge_test_metadata): mdat_merged.update({Info.HEIGHT: mdat1[Info.HEIGHT] + mdat2[Info.HEIGHT], Info.Y2: mdat2[Info.Y2]}) - assert merge_metadata_vertically(mdat1, mdat2) == mdat_merged + return mdat1, mdat2, mdat_merged @pytest.fixture def merge_test_metadata(base_patch_metadata): - return juxt(*repeat(deepcopy, 3))(base_patch_metadata) + return juxt(*repeat(deepcopy, 3))(base_patch_metadata) +@pytest.fixture +def merge_test_image_metadata_pairs(): + pass -# class Stitcher: -# @staticmethod -# def groupby(pairs, coord): -# coord_getter = make_coord_getter(coord) -# pairs = sorted(pairs, key=coord_getter) -# return map(compose(list, second), groupby(pairs, coord_getter)) -# -# def stitch(self, pairs: Iterable[ImageMetadataPair]) -> ImageMetadataPair: -# groups = self.groupby(pairs, "x1") -# groups = chain.from_iterable(map(rpartial(self.groupby, "x2"), groups)) -# groups = map(partial(sorted, key=y1_getter), groups) -# groups = map(merge_group, groups) + +def test_concat_images_horizontally(horizontal_merge_test_metadata): + mdat1, mdat2, mdat_merged = horizontal_merge_test_metadata + im1, im2, im_merged = map(random_single_color_image_from_metadata, [mdat1, mdat2, mdat_merged]) + assert concat_images_horizontally(im1, im2, mdat_merged).size == im_merged.size + + +def test_concat_images_vertically(vertical_merge_test_metadata): + mdat1, mdat2, mdat_merged = vertical_merge_test_metadata + im1, im2, im_merged = map(random_single_color_image_from_metadata, [mdat1, mdat2, mdat_merged]) + assert concat_images_vertically(im1, im2, mdat_merged).size == im_merged.size + + +class Stitcher: + @staticmethod + def groupby(pairs, coord): + coord_getter = make_coord_getter(coord) + pairs = sorted(pairs, key=coord_getter) + return map(compose(list, second), groupby(pairs, coord_getter)) + + def stitch(self, pairs: Iterable[ImageMetadataPair]) -> ImageMetadataPair: + groups = self.groupby(pairs, "x1") + groups = chain.from_iterable(map(rpartial(self.groupby, "x2"), groups)) + groups = map(partial(sorted, key=y1_getter), groups) + groups = map(merge_group, groups) @pytest.mark.parametrize("width", [160])