refactoring; added metadata merging logic

This commit is contained in:
Matthias Bisping 2022-04-06 15:41:42 +02:00
parent 7e2696d5c5
commit 3266e0af58
2 changed files with 138 additions and 46 deletions

View File

@ -1,15 +1,17 @@
from copy import deepcopy
from functools import partial
from itertools import starmap, chain
from itertools import starmap, chain, repeat
from operator import itemgetter
from typing import Iterable, List
import fpdf
import pytest
from funcy import merge, second, compose, rpartial
from funcy import merge, second, compose, rpartial, curry, juxt, project, omit
from image_prediction.image_extractor.extractor import ImageMetadataPair
from image_prediction.info import Info
from test.conftest import get_base_position_metadata, add_image, random_single_color_image_from_metadata
from test.utils.stitching import BoxSplitter
from test.utils.stitching import BoxSplitter, VerticalKeyMapper, HorizontalKeyMapper
from itertools import groupby
@ -29,42 +31,106 @@ def make_coord_getter(c):
}[c]
def merge_group(group, axis="y"):
y1_getter, y2_getter = map(make_coord_getter, ("y1", "y2"))
group = list(group)
current_pair = group.pop(0)
for pair in group:
if y2_getter(current_pair) == y1_getter(pair):
current_box = merge_pair(current_pair, pair)
def make_length_getter(dim):
return {
"width": make_getter(Info.WIDTH),
"height": make_getter(Info.HEIGHT),
}[dim]
def merge_pair(p1, p2):
pass
x1_getter, y1_getter, x2_getter, y2_getter = map(make_coord_getter, ("x1", "y1", "x2", "y2"))
width_getter, height_getter = map(make_length_getter, ("width", "height"))
# def merge_group(group, axis="y"):
#
# group = list(group)
# current_pair = group.pop(0)
# for pair in group:
# if y2_getter(current_pair) == y1_getter(pair):
# current_box = merge_pair(current_pair, pair)
def merge_metadata_horizontally(m1, m2):
m1, m2 = map(HorizontalKeyMapper, [m1, m2])
return merge_metadata(m1, m2)
def merge_metadata_vertically(m1, m2):
m1, m2 = map(VerticalKeyMapper, [m1, m2])
return merge_metadata(m1, m2)
def merge_metadata(m1, m2):
c1 = min(m1.c1, m2.c1)
c2 = max(m1.c2, m2.c2)
dim = m1.dim + m2.dim
merged = deepcopy(m1)
merged.dim = dim
merged.c1 = c1
merged.c2 = c2
return merged.wrapped
def merge_pair_horizontally(p1: ImageMetadataPair, p2: ImageMetadataPair):
mdat_merged = merge_metadata_horizontally(p1.metadata, p2.metadata)
# def merge_pair(p1, p2):
#
# assert p1.metadata[Info.PAGE_IDX] == p2.metadta[Info.PAGE_IDX]
def test_merge_metadata_horizontally(merge_test_metadata):
mdat1, mdat2, mdat_merged = merge_test_metadata
mdat2[Info.X1] = mdat1[Info.X2]
mdat2[Info.X2] = mdat2[Info.X1] + mdat2[Info.WIDTH]
mdat_merged.update({Info.WIDTH: mdat1[Info.WIDTH] + mdat2[Info.WIDTH], Info.X2: mdat2[Info.X2]})
assert merge_metadata_horizontally(mdat1, mdat2) == mdat_merged
def test_merge_metadata_vertically(merge_test_metadata):
mdat1, mdat2, mdat_merged = merge_test_metadata
mdat2[Info.Y1] = mdat1[Info.Y2]
mdat2[Info.Y2] = mdat2[Info.Y1] + mdat2[Info.HEIGHT]
mdat_merged.update({Info.HEIGHT: mdat1[Info.HEIGHT] + mdat2[Info.HEIGHT], Info.Y2: mdat2[Info.Y2]})
assert merge_metadata_vertically(mdat1, mdat2) == mdat_merged
@pytest.fixture
def merge_test_metadata(base_patch_metadata):
return juxt(*repeat(deepcopy, 3))(base_patch_metadata)
class Stitcher:
@staticmethod
def groupby(pairs, coord):
coord_getter = make_coord_getter(coord)
pairs = sorted(pairs, key=coord_getter)
return map(compose(list, second), groupby(pairs, coord_getter))
def stitch(self, pairs: Iterable[ImageMetadataPair]) -> ImageMetadataPair:
groups = self.groupby(pairs, "x1")
groups = chain.from_iterable(map(rpartial(self.groupby, "x2"), groups))
groups = map(partial(sorted, key=make_coord_getter("y1")), groups)
groups = map(merge_group, groups)
# class Stitcher:
# @staticmethod
# def groupby(pairs, coord):
# coord_getter = make_coord_getter(coord)
# pairs = sorted(pairs, key=coord_getter)
# return map(compose(list, second), groupby(pairs, coord_getter))
#
# def stitch(self, pairs: Iterable[ImageMetadataPair]) -> ImageMetadataPair:
# groups = self.groupby(pairs, "x1")
# groups = chain.from_iterable(map(rpartial(self.groupby, "x2"), groups))
# groups = map(partial(sorted, key=y1_getter), groups)
# groups = map(merge_group, groups)
@pytest.mark.parametrize("width", [160])
@pytest.mark.parametrize("height", [90])
@pytest.mark.parametrize("page_width", [int(160 * 1.1)])
@pytest.mark.parametrize("page_height", [int(90 * 1.1)])
@pytest.mark.skip()
def test_image_stitcher(patches_metadata, base_patch_metadata):
# noinspection PyTypeChecker
assert Stitcher().stitch(patch_image_metadata_pairs).metadata == base_patch_metadata

View File

@ -2,6 +2,7 @@ import abc
import random
from copy import deepcopy
from itertools import chain
from operator import itemgetter
from funcy import rpartial, juxt
@ -9,25 +10,50 @@ from image_prediction.info import Info
class SplitKeyMapper(abc.ABC):
def __init__(self, box: dict, keymap: dict):
self.box = box
def __init__(self, wrapped: dict, keymap: dict):
self.wrapped = wrapped
self.keymap = keymap
def __getitem__(self, item):
return self.box[self.keymap[item]]
def __getitem__(self, key):
return self.wrapped[self.keymap[key]]
def __setitem__(self, key, value):
self.box[self.keymap[key]] = value
self.wrapped[self.keymap[key]] = value
@property
def dim(self):
return self.__getitem__("dim")
@dim.setter
def dim(self, value):
self.__setitem__("dim", value)
@property
def c1(self):
return self.__getitem__("c1")
@c1.setter
def c1(self, value):
self.__setitem__("c1", value)
@property
def c2(self):
return self.__getitem__("c2")
@c2.setter
def c2(self, value):
self.__setitem__("c2", value)
class HorizontalKeyMapper(SplitKeyMapper):
def __init__(self, box: dict):
super().__init__(box, {"dim": Info.WIDTH, "c1": Info.X1, "c2": Info.X2})
def __init__(self, wrapped: dict):
super().__init__(wrapped, {"dim": Info.WIDTH, "c1": Info.X1, "c2": Info.X2})
class VerticalKeyMapper(SplitKeyMapper):
def __init__(self, box: dict):
super().__init__(box, {"dim": Info.HEIGHT, "c1": Info.Y1, "c2": Info.Y2})
def __init__(self, wrapped: dict):
super().__init__(wrapped, {"dim": Info.HEIGHT, "c1": Info.Y1, "c2": Info.Y2})
class BoxSplitter:
@ -71,25 +97,25 @@ class BoxSplitter:
return (
self.__get_child_boxes(wrapped_box)
if self.__large_enough(wrapped_box)
else self.__base_case(wrapped_box.box)
else self.__base_case(wrapped_box.wrapped)
)
@staticmethod
def __large_enough(wrapped_box: SplitKeyMapper):
return wrapped_box["dim"] >= 10
return wrapped_box.dim >= 10
@staticmethod
def __get_child_boxes(wrapped_box: SplitKeyMapper):
split_len = random.randint(5, wrapped_box["dim"] - 5)
split_point = wrapped_box["c1"] + split_len
split_len = random.randint(5, wrapped_box.dim - 5)
split_point = wrapped_box.c1 + split_len
box_left, box_right = juxt(deepcopy, deepcopy)(wrapped_box)
box_left["dim"] = split_len
box_right["dim"] = wrapped_box["dim"] - split_len
box_left.dim = split_len
box_right.dim = wrapped_box.dim - split_len
box_left["c2"] = split_point
box_right["c1"] = split_point
box_left.c2 = split_point
box_right.c1 = split_point
return box_left.box, box_right.box
return box_left.wrapped, box_right.wrapped