diff --git a/test/unit_tests/image_stitcher_test.py b/test/unit_tests/image_stitcher_test.py index e383b10..7e451ec 100644 --- a/test/unit_tests/image_stitcher_test.py +++ b/test/unit_tests/image_stitcher_test.py @@ -1,17 +1,73 @@ -from itertools import starmap +from functools import partial +from itertools import starmap, chain +from typing import Iterable, List import fpdf import pytest -from funcy import merge +from funcy import merge, second, compose, rpartial from image_prediction.image_extractor.extractor import ImageMetadataPair from image_prediction.info import Info from test.conftest import get_base_position_metadata, add_image, random_single_color_image_from_metadata from test.utils.stitching import BoxSplitter +from itertools import groupby -# def test_image_stitcher(patches_metadata): -# assert Stitcher().stitch(patches) +def make_getter(key): + def getter(pair): + return pair.metadata[key] + + return getter + + +def make_coord_getter(c): + return { + "x1": make_getter(Info.X1), + "x2": make_getter(Info.X2), + "y1": make_getter(Info.Y1), + "y2": make_getter(Info.Y2), + }[c] + + +def merge_group(group, axis="y"): + + y1_getter, y2_getter = map(make_coord_getter, ("y1", "y2")) + + group = list(group) + current_pair = group.pop(0) + for pair in group: + if y2_getter(current_pair) == y1_getter(pair): + current_box = merge_pair(current_pair, pair) + + +def merge_pair(p1, p2): + pass + + + +class Stitcher: + + @staticmethod + def groupby(pairs, coord): + coord_getter = make_coord_getter(coord) + pairs = sorted(pairs, key=coord_getter) + return map(compose(list, second), groupby(pairs, coord_getter)) + + def stitch(self, pairs: Iterable[ImageMetadataPair]) -> ImageMetadataPair: + groups = self.groupby(pairs, "x1") + groups = chain.from_iterable(map(rpartial(self.groupby, "x2"), groups)) + groups = map(partial(sorted, key=make_coord_getter("y1")), groups) + groups = map(merge_group, groups) + + + +@pytest.mark.parametrize("width", [160]) +@pytest.mark.parametrize("height", [90]) +@pytest.mark.parametrize("page_width", [int(160 * 1.1)]) +@pytest.mark.parametrize("page_height", [int(90 * 1.1)]) +def test_image_stitcher(patches_metadata, base_patch_metadata): + # noinspection PyTypeChecker + assert Stitcher().stitch(patch_image_metadata_pairs).metadata == base_patch_metadata @pytest.mark.parametrize("width", [160]) @@ -29,9 +85,9 @@ def test_partial_image_metadata_pairs(patch_image_metadata_pairs, page_width, pa @pytest.fixture -def patch_image_metadata_pairs(patches_metadata): +def patch_image_metadata_pairs(patches_metadata) -> List[ImageMetadataPair]: images = map(random_single_color_image_from_metadata, patches_metadata) - return starmap(ImageMetadataPair, zip(images, patches_metadata)) + return list(starmap(ImageMetadataPair, zip(images, patches_metadata))) @pytest.fixture