From 3266e0af584d067deaf2f659818d5649ea38e053 Mon Sep 17 00:00:00 2001 From: Matthias Bisping Date: Wed, 6 Apr 2022 15:41:42 +0200 Subject: [PATCH] refactoring; added metadata merging logic --- test/unit_tests/image_stitcher_test.py | 122 +++++++++++++++++++------ test/utils/stitching.py | 62 +++++++++---- 2 files changed, 138 insertions(+), 46 deletions(-) diff --git a/test/unit_tests/image_stitcher_test.py b/test/unit_tests/image_stitcher_test.py index 7e451ec..d53fc38 100644 --- a/test/unit_tests/image_stitcher_test.py +++ b/test/unit_tests/image_stitcher_test.py @@ -1,15 +1,17 @@ +from copy import deepcopy from functools import partial -from itertools import starmap, chain +from itertools import starmap, chain, repeat +from operator import itemgetter from typing import Iterable, List import fpdf import pytest -from funcy import merge, second, compose, rpartial +from funcy import merge, second, compose, rpartial, curry, juxt, project, omit from image_prediction.image_extractor.extractor import ImageMetadataPair from image_prediction.info import Info from test.conftest import get_base_position_metadata, add_image, random_single_color_image_from_metadata -from test.utils.stitching import BoxSplitter +from test.utils.stitching import BoxSplitter, VerticalKeyMapper, HorizontalKeyMapper from itertools import groupby @@ -29,42 +31,106 @@ def make_coord_getter(c): }[c] -def merge_group(group, axis="y"): - - y1_getter, y2_getter = map(make_coord_getter, ("y1", "y2")) - - group = list(group) - current_pair = group.pop(0) - for pair in group: - if y2_getter(current_pair) == y1_getter(pair): - current_box = merge_pair(current_pair, pair) +def make_length_getter(dim): + return { + "width": make_getter(Info.WIDTH), + "height": make_getter(Info.HEIGHT), + }[dim] -def merge_pair(p1, p2): - pass +x1_getter, y1_getter, x2_getter, y2_getter = map(make_coord_getter, ("x1", "y1", "x2", "y2")) +width_getter, height_getter = map(make_length_getter, ("width", "height")) + + +# def merge_group(group, axis="y"): +# +# group = list(group) +# current_pair = group.pop(0) +# for pair in group: +# if y2_getter(current_pair) == y1_getter(pair): +# current_box = merge_pair(current_pair, pair) + + +def merge_metadata_horizontally(m1, m2): + m1, m2 = map(HorizontalKeyMapper, [m1, m2]) + return merge_metadata(m1, m2) + + +def merge_metadata_vertically(m1, m2): + m1, m2 = map(VerticalKeyMapper, [m1, m2]) + return merge_metadata(m1, m2) + + +def merge_metadata(m1, m2): + + c1 = min(m1.c1, m2.c1) + c2 = max(m1.c2, m2.c2) + dim = m1.dim + m2.dim + + merged = deepcopy(m1) + merged.dim = dim + merged.c1 = c1 + merged.c2 = c2 + + return merged.wrapped + + +def merge_pair_horizontally(p1: ImageMetadataPair, p2: ImageMetadataPair): + mdat_merged = merge_metadata_horizontally(p1.metadata, p2.metadata) + + +# def merge_pair(p1, p2): +# +# assert p1.metadata[Info.PAGE_IDX] == p2.metadta[Info.PAGE_IDX] + + +def test_merge_metadata_horizontally(merge_test_metadata): + mdat1, mdat2, mdat_merged = merge_test_metadata + + mdat2[Info.X1] = mdat1[Info.X2] + mdat2[Info.X2] = mdat2[Info.X1] + mdat2[Info.WIDTH] + + mdat_merged.update({Info.WIDTH: mdat1[Info.WIDTH] + mdat2[Info.WIDTH], Info.X2: mdat2[Info.X2]}) + + assert merge_metadata_horizontally(mdat1, mdat2) == mdat_merged + + +def test_merge_metadata_vertically(merge_test_metadata): + mdat1, mdat2, mdat_merged = merge_test_metadata + + mdat2[Info.Y1] = mdat1[Info.Y2] + mdat2[Info.Y2] = mdat2[Info.Y1] + mdat2[Info.HEIGHT] + + mdat_merged.update({Info.HEIGHT: mdat1[Info.HEIGHT] + mdat2[Info.HEIGHT], Info.Y2: mdat2[Info.Y2]}) + + assert merge_metadata_vertically(mdat1, mdat2) == mdat_merged + + +@pytest.fixture +def merge_test_metadata(base_patch_metadata): + return juxt(*repeat(deepcopy, 3))(base_patch_metadata) -class Stitcher: - - @staticmethod - def groupby(pairs, coord): - coord_getter = make_coord_getter(coord) - pairs = sorted(pairs, key=coord_getter) - return map(compose(list, second), groupby(pairs, coord_getter)) - - def stitch(self, pairs: Iterable[ImageMetadataPair]) -> ImageMetadataPair: - groups = self.groupby(pairs, "x1") - groups = chain.from_iterable(map(rpartial(self.groupby, "x2"), groups)) - groups = map(partial(sorted, key=make_coord_getter("y1")), groups) - groups = map(merge_group, groups) - +# class Stitcher: +# @staticmethod +# def groupby(pairs, coord): +# coord_getter = make_coord_getter(coord) +# pairs = sorted(pairs, key=coord_getter) +# return map(compose(list, second), groupby(pairs, coord_getter)) +# +# def stitch(self, pairs: Iterable[ImageMetadataPair]) -> ImageMetadataPair: +# groups = self.groupby(pairs, "x1") +# groups = chain.from_iterable(map(rpartial(self.groupby, "x2"), groups)) +# groups = map(partial(sorted, key=y1_getter), groups) +# groups = map(merge_group, groups) @pytest.mark.parametrize("width", [160]) @pytest.mark.parametrize("height", [90]) @pytest.mark.parametrize("page_width", [int(160 * 1.1)]) @pytest.mark.parametrize("page_height", [int(90 * 1.1)]) +@pytest.mark.skip() def test_image_stitcher(patches_metadata, base_patch_metadata): # noinspection PyTypeChecker assert Stitcher().stitch(patch_image_metadata_pairs).metadata == base_patch_metadata diff --git a/test/utils/stitching.py b/test/utils/stitching.py index 06b406a..6bdadbc 100644 --- a/test/utils/stitching.py +++ b/test/utils/stitching.py @@ -2,6 +2,7 @@ import abc import random from copy import deepcopy from itertools import chain +from operator import itemgetter from funcy import rpartial, juxt @@ -9,25 +10,50 @@ from image_prediction.info import Info class SplitKeyMapper(abc.ABC): - def __init__(self, box: dict, keymap: dict): - self.box = box + + def __init__(self, wrapped: dict, keymap: dict): + self.wrapped = wrapped self.keymap = keymap - def __getitem__(self, item): - return self.box[self.keymap[item]] + def __getitem__(self, key): + return self.wrapped[self.keymap[key]] def __setitem__(self, key, value): - self.box[self.keymap[key]] = value + self.wrapped[self.keymap[key]] = value + + @property + def dim(self): + return self.__getitem__("dim") + + @dim.setter + def dim(self, value): + self.__setitem__("dim", value) + + @property + def c1(self): + return self.__getitem__("c1") + + @c1.setter + def c1(self, value): + self.__setitem__("c1", value) + + @property + def c2(self): + return self.__getitem__("c2") + + @c2.setter + def c2(self, value): + self.__setitem__("c2", value) class HorizontalKeyMapper(SplitKeyMapper): - def __init__(self, box: dict): - super().__init__(box, {"dim": Info.WIDTH, "c1": Info.X1, "c2": Info.X2}) + def __init__(self, wrapped: dict): + super().__init__(wrapped, {"dim": Info.WIDTH, "c1": Info.X1, "c2": Info.X2}) class VerticalKeyMapper(SplitKeyMapper): - def __init__(self, box: dict): - super().__init__(box, {"dim": Info.HEIGHT, "c1": Info.Y1, "c2": Info.Y2}) + def __init__(self, wrapped: dict): + super().__init__(wrapped, {"dim": Info.HEIGHT, "c1": Info.Y1, "c2": Info.Y2}) class BoxSplitter: @@ -71,25 +97,25 @@ class BoxSplitter: return ( self.__get_child_boxes(wrapped_box) if self.__large_enough(wrapped_box) - else self.__base_case(wrapped_box.box) + else self.__base_case(wrapped_box.wrapped) ) @staticmethod def __large_enough(wrapped_box: SplitKeyMapper): - return wrapped_box["dim"] >= 10 + return wrapped_box.dim >= 10 @staticmethod def __get_child_boxes(wrapped_box: SplitKeyMapper): - split_len = random.randint(5, wrapped_box["dim"] - 5) - split_point = wrapped_box["c1"] + split_len + split_len = random.randint(5, wrapped_box.dim - 5) + split_point = wrapped_box.c1 + split_len box_left, box_right = juxt(deepcopy, deepcopy)(wrapped_box) - box_left["dim"] = split_len - box_right["dim"] = wrapped_box["dim"] - split_len + box_left.dim = split_len + box_right.dim = wrapped_box.dim - split_len - box_left["c2"] = split_point - box_right["c1"] = split_point + box_left.c2 = split_point + box_right.c1 = split_point - return box_left.box, box_right.box + return box_left.wrapped, box_right.wrapped