stitching impl wip

This commit is contained in:
Matthias Bisping 2022-04-05 23:39:17 +02:00
parent 302613bf2b
commit 7e2696d5c5

View File

@ -1,17 +1,73 @@
from itertools import starmap from functools import partial
from itertools import starmap, chain
from typing import Iterable, List
import fpdf import fpdf
import pytest import pytest
from funcy import merge from funcy import merge, second, compose, rpartial
from image_prediction.image_extractor.extractor import ImageMetadataPair from image_prediction.image_extractor.extractor import ImageMetadataPair
from image_prediction.info import Info from image_prediction.info import Info
from test.conftest import get_base_position_metadata, add_image, random_single_color_image_from_metadata from test.conftest import get_base_position_metadata, add_image, random_single_color_image_from_metadata
from test.utils.stitching import BoxSplitter from test.utils.stitching import BoxSplitter
from itertools import groupby
# def test_image_stitcher(patches_metadata): def make_getter(key):
# assert Stitcher().stitch(patches) def getter(pair):
return pair.metadata[key]
return getter
def make_coord_getter(c):
return {
"x1": make_getter(Info.X1),
"x2": make_getter(Info.X2),
"y1": make_getter(Info.Y1),
"y2": make_getter(Info.Y2),
}[c]
def merge_group(group, axis="y"):
y1_getter, y2_getter = map(make_coord_getter, ("y1", "y2"))
group = list(group)
current_pair = group.pop(0)
for pair in group:
if y2_getter(current_pair) == y1_getter(pair):
current_box = merge_pair(current_pair, pair)
def merge_pair(p1, p2):
pass
class Stitcher:
@staticmethod
def groupby(pairs, coord):
coord_getter = make_coord_getter(coord)
pairs = sorted(pairs, key=coord_getter)
return map(compose(list, second), groupby(pairs, coord_getter))
def stitch(self, pairs: Iterable[ImageMetadataPair]) -> ImageMetadataPair:
groups = self.groupby(pairs, "x1")
groups = chain.from_iterable(map(rpartial(self.groupby, "x2"), groups))
groups = map(partial(sorted, key=make_coord_getter("y1")), groups)
groups = map(merge_group, groups)
@pytest.mark.parametrize("width", [160])
@pytest.mark.parametrize("height", [90])
@pytest.mark.parametrize("page_width", [int(160 * 1.1)])
@pytest.mark.parametrize("page_height", [int(90 * 1.1)])
def test_image_stitcher(patches_metadata, base_patch_metadata):
# noinspection PyTypeChecker
assert Stitcher().stitch(patch_image_metadata_pairs).metadata == base_patch_metadata
@pytest.mark.parametrize("width", [160]) @pytest.mark.parametrize("width", [160])
@ -29,9 +85,9 @@ def test_partial_image_metadata_pairs(patch_image_metadata_pairs, page_width, pa
@pytest.fixture @pytest.fixture
def patch_image_metadata_pairs(patches_metadata): def patch_image_metadata_pairs(patches_metadata) -> List[ImageMetadataPair]:
images = map(random_single_color_image_from_metadata, patches_metadata) images = map(random_single_color_image_from_metadata, patches_metadata)
return starmap(ImageMetadataPair, zip(images, patches_metadata)) return list(starmap(ImageMetadataPair, zip(images, patches_metadata)))
@pytest.fixture @pytest.fixture