refactoring; added metadata merging logic

This commit is contained in:
Matthias Bisping 2022-04-06 15:41:42 +02:00
parent 7e2696d5c5
commit 3266e0af58
2 changed files with 138 additions and 46 deletions

View File

@ -1,15 +1,17 @@
from copy import deepcopy
from functools import partial from functools import partial
from itertools import starmap, chain from itertools import starmap, chain, repeat
from operator import itemgetter
from typing import Iterable, List from typing import Iterable, List
import fpdf import fpdf
import pytest import pytest
from funcy import merge, second, compose, rpartial from funcy import merge, second, compose, rpartial, curry, juxt, project, omit
from image_prediction.image_extractor.extractor import ImageMetadataPair from image_prediction.image_extractor.extractor import ImageMetadataPair
from image_prediction.info import Info from image_prediction.info import Info
from test.conftest import get_base_position_metadata, add_image, random_single_color_image_from_metadata from test.conftest import get_base_position_metadata, add_image, random_single_color_image_from_metadata
from test.utils.stitching import BoxSplitter from test.utils.stitching import BoxSplitter, VerticalKeyMapper, HorizontalKeyMapper
from itertools import groupby from itertools import groupby
@ -29,42 +31,106 @@ def make_coord_getter(c):
}[c] }[c]
def merge_group(group, axis="y"): def make_length_getter(dim):
return {
y1_getter, y2_getter = map(make_coord_getter, ("y1", "y2")) "width": make_getter(Info.WIDTH),
"height": make_getter(Info.HEIGHT),
group = list(group) }[dim]
current_pair = group.pop(0)
for pair in group:
if y2_getter(current_pair) == y1_getter(pair):
current_box = merge_pair(current_pair, pair)
def merge_pair(p1, p2): x1_getter, y1_getter, x2_getter, y2_getter = map(make_coord_getter, ("x1", "y1", "x2", "y2"))
pass width_getter, height_getter = map(make_length_getter, ("width", "height"))
# def merge_group(group, axis="y"):
#
# group = list(group)
# current_pair = group.pop(0)
# for pair in group:
# if y2_getter(current_pair) == y1_getter(pair):
# current_box = merge_pair(current_pair, pair)
def merge_metadata_horizontally(m1, m2):
    """Merge two metadata mappings along the horizontal axis.

    Both inputs are wrapped in a ``HorizontalKeyMapper`` and handed to
    ``merge_metadata``; the merged plain mapping it returns is passed through.

    :param m1: metadata of the first box.
    :param m2: metadata of the second box.
    :return: the result of ``merge_metadata`` for the two wrapped inputs.
    """
    return merge_metadata(HorizontalKeyMapper(m1), HorizontalKeyMapper(m2))
def merge_metadata_vertically(m1, m2):
    """Merge two metadata mappings along the vertical axis.

    Both inputs are wrapped in a ``VerticalKeyMapper`` and handed to
    ``merge_metadata``; the merged plain mapping it returns is passed through.

    :param m1: metadata of the first box.
    :param m2: metadata of the second box.
    :return: the result of ``merge_metadata`` for the two wrapped inputs.
    """
    return merge_metadata(VerticalKeyMapper(m1), VerticalKeyMapper(m2))
def merge_metadata(m1, m2):
    """Combine two key-mapped metadata objects into one merged mapping.

    The result starts as a deep copy of ``m1``, so neither input is modified.
    On the copy, ``dim`` becomes the sum of both ``dim`` values, ``c1`` the
    smaller of the two ``c1`` values and ``c2`` the larger of the two ``c2``
    values.

    :param m1: first mapper object exposing ``dim``, ``c1``, ``c2`` and
        ``wrapped``.
    :param m2: second mapper object exposing ``dim``, ``c1`` and ``c2``.
    :return: the ``wrapped`` object of the merged copy.
    """
    merged = deepcopy(m1)
    merged.dim = m1.dim + m2.dim
    merged.c1 = min(m1.c1, m2.c1)
    merged.c2 = max(m1.c2, m2.c2)
    return merged.wrapped
def merge_pair_horizontally(p1: ImageMetadataPair, p2: ImageMetadataPair):
    """Merge the metadata of two image/metadata pairs horizontally.

    NOTE(review): only the metadata is merged, and the result is bound to a
    local name without being returned, so this function currently always
    returns ``None``. It looks unfinished -- confirm what it should return.

    :param p1: first pair; only its ``metadata`` attribute is read.
    :param p2: second pair; only its ``metadata`` attribute is read.
    """
    first_metadata, second_metadata = p1.metadata, p2.metadata
    mdat_merged = merge_metadata_horizontally(first_metadata, second_metadata)
# def merge_pair(p1, p2):
#
#     assert p1.metadata[Info.PAGE_IDX] == p2.metadata[Info.PAGE_IDX]
def test_merge_metadata_horizontally(merge_test_metadata):
    """Merging two horizontally adjacent boxes sums widths and extends X2."""
    left, right, expected = merge_test_metadata

    # Move the second box so that it starts where the first one ends.
    right[Info.X1] = left[Info.X2]
    right[Info.X2] = right[Info.X1] + right[Info.WIDTH]

    # The merged box keeps the first box's values except width and right edge.
    expected[Info.WIDTH] = left[Info.WIDTH] + right[Info.WIDTH]
    expected[Info.X2] = right[Info.X2]

    assert merge_metadata_horizontally(left, right) == expected
def test_merge_metadata_vertically(merge_test_metadata):
    """Merging two vertically adjacent boxes sums heights and extends Y2."""
    upper, lower, expected = merge_test_metadata

    # Move the second box so that it starts where the first one ends.
    lower[Info.Y1] = upper[Info.Y2]
    lower[Info.Y2] = lower[Info.Y1] + lower[Info.HEIGHT]

    # The merged box keeps the first box's values except height and lower edge.
    expected[Info.HEIGHT] = upper[Info.HEIGHT] + lower[Info.HEIGHT]
    expected[Info.Y2] = lower[Info.Y2]

    assert merge_metadata_vertically(upper, lower) == expected
@pytest.fixture
def merge_test_metadata(base_patch_metadata):
    """Provide three deep copies of ``base_patch_metadata``.

    The merge tests unpack the result as (first box, second box, expected
    merge result) and modify each copy separately.

    NOTE(review): whether this is a list or a lazy iterator depends on the
    installed funcy version's ``juxt`` -- the tests only unpack it once.
    """
    copy_three_times = juxt(deepcopy, deepcopy, deepcopy)
    return copy_three_times(base_patch_metadata)
class Stitcher: # class Stitcher:
# @staticmethod
@staticmethod # def groupby(pairs, coord):
def groupby(pairs, coord): # coord_getter = make_coord_getter(coord)
coord_getter = make_coord_getter(coord) # pairs = sorted(pairs, key=coord_getter)
pairs = sorted(pairs, key=coord_getter) # return map(compose(list, second), groupby(pairs, coord_getter))
return map(compose(list, second), groupby(pairs, coord_getter)) #
# def stitch(self, pairs: Iterable[ImageMetadataPair]) -> ImageMetadataPair:
def stitch(self, pairs: Iterable[ImageMetadataPair]) -> ImageMetadataPair: # groups = self.groupby(pairs, "x1")
groups = self.groupby(pairs, "x1") # groups = chain.from_iterable(map(rpartial(self.groupby, "x2"), groups))
groups = chain.from_iterable(map(rpartial(self.groupby, "x2"), groups)) # groups = map(partial(sorted, key=y1_getter), groups)
groups = map(partial(sorted, key=make_coord_getter("y1")), groups) # groups = map(merge_group, groups)
groups = map(merge_group, groups)
@pytest.mark.parametrize("width", [160]) @pytest.mark.parametrize("width", [160])
@pytest.mark.parametrize("height", [90]) @pytest.mark.parametrize("height", [90])
@pytest.mark.parametrize("page_width", [int(160 * 1.1)]) @pytest.mark.parametrize("page_width", [int(160 * 1.1)])
@pytest.mark.parametrize("page_height", [int(90 * 1.1)]) @pytest.mark.parametrize("page_height", [int(90 * 1.1)])
@pytest.mark.skip()
def test_image_stitcher(patches_metadata, base_patch_metadata): def test_image_stitcher(patches_metadata, base_patch_metadata):
# noinspection PyTypeChecker # noinspection PyTypeChecker
assert Stitcher().stitch(patch_image_metadata_pairs).metadata == base_patch_metadata assert Stitcher().stitch(patch_image_metadata_pairs).metadata == base_patch_metadata

View File

@ -2,6 +2,7 @@ import abc
import random import random
from copy import deepcopy from copy import deepcopy
from itertools import chain from itertools import chain
from operator import itemgetter
from funcy import rpartial, juxt from funcy import rpartial, juxt
@ -9,25 +10,50 @@ from image_prediction.info import Info
class SplitKeyMapper(abc.ABC): class SplitKeyMapper(abc.ABC):
def __init__(self, box: dict, keymap: dict):
self.box = box def __init__(self, wrapped: dict, keymap: dict):
self.wrapped = wrapped
self.keymap = keymap self.keymap = keymap
def __getitem__(self, item): def __getitem__(self, key):
return self.box[self.keymap[item]] return self.wrapped[self.keymap[key]]
def __setitem__(self, key, value): def __setitem__(self, key, value):
self.box[self.keymap[key]] = value self.wrapped[self.keymap[key]] = value
@property
def dim(self):
    """Value stored under the logical key ``"dim"`` (read via item access)."""
    return self["dim"]

@dim.setter
def dim(self, value):
    # Delegates to item assignment so the key mapping is applied in one place.
    self["dim"] = value
@property
def c1(self):
    """Value stored under the logical key ``"c1"`` (read via item access)."""
    return self["c1"]

@c1.setter
def c1(self, value):
    # Delegates to item assignment so the key mapping is applied in one place.
    self["c1"] = value
@property
def c2(self):
    """Value stored under the logical key ``"c2"`` (read via item access)."""
    return self["c2"]

@c2.setter
def c2(self, value):
    # Delegates to item assignment so the key mapping is applied in one place.
    self["c2"] = value
class HorizontalKeyMapper(SplitKeyMapper): class HorizontalKeyMapper(SplitKeyMapper):
def __init__(self, box: dict): def __init__(self, wrapped: dict):
super().__init__(box, {"dim": Info.WIDTH, "c1": Info.X1, "c2": Info.X2}) super().__init__(wrapped, {"dim": Info.WIDTH, "c1": Info.X1, "c2": Info.X2})
class VerticalKeyMapper(SplitKeyMapper): class VerticalKeyMapper(SplitKeyMapper):
def __init__(self, box: dict): def __init__(self, wrapped: dict):
super().__init__(box, {"dim": Info.HEIGHT, "c1": Info.Y1, "c2": Info.Y2}) super().__init__(wrapped, {"dim": Info.HEIGHT, "c1": Info.Y1, "c2": Info.Y2})
class BoxSplitter: class BoxSplitter:
@ -71,25 +97,25 @@ class BoxSplitter:
return ( return (
self.__get_child_boxes(wrapped_box) self.__get_child_boxes(wrapped_box)
if self.__large_enough(wrapped_box) if self.__large_enough(wrapped_box)
else self.__base_case(wrapped_box.box) else self.__base_case(wrapped_box.wrapped)
) )
@staticmethod @staticmethod
def __large_enough(wrapped_box: SplitKeyMapper): def __large_enough(wrapped_box: SplitKeyMapper):
return wrapped_box["dim"] >= 10 return wrapped_box.dim >= 10
@staticmethod @staticmethod
def __get_child_boxes(wrapped_box: SplitKeyMapper): def __get_child_boxes(wrapped_box: SplitKeyMapper):
split_len = random.randint(5, wrapped_box["dim"] - 5) split_len = random.randint(5, wrapped_box.dim - 5)
split_point = wrapped_box["c1"] + split_len split_point = wrapped_box.c1 + split_len
box_left, box_right = juxt(deepcopy, deepcopy)(wrapped_box) box_left, box_right = juxt(deepcopy, deepcopy)(wrapped_box)
box_left["dim"] = split_len box_left.dim = split_len
box_right["dim"] = wrapped_box["dim"] - split_len box_right.dim = wrapped_box.dim - split_len
box_left["c2"] = split_point box_left.c2 = split_point
box_right["c1"] = split_point box_right.c1 = split_point
return box_left.box, box_right.box return box_left.wrapped, box_right.wrapped