refactoring

This commit is contained in:
Matthias Bisping 2022-04-08 13:56:57 +02:00
parent 1b10445f91
commit 3b18fc6158
9 changed files with 118 additions and 120 deletions

View File

@ -1,65 +0,0 @@
from itertools import groupby, chain
from typing import Iterable, List
from funcy import compose, second
from image_prediction.image_extractor.extractor import ImageMetadataPair
from image_prediction.stitcher.utils import make_coord_getter, make_group_merger
from image_prediction.utils.generic import until
def group_by_coordinate(pairs, coord_getter):
pairs = sorted(pairs, key=coord_getter)
return map(compose(list, second), groupby(pairs, coord_getter))
class CoordGrouper:
def __init__(self, axis):
self.c1_getter = make_coord_getter(f"{other_axis(axis)}1")
self.c2_getter = make_coord_getter(f"{other_axis(axis)}2")
def get_c1(self, pair: ImageMetadataPair):
return self.c1_getter(pair)
def group_pairs_by_c1(self, pairs):
return group_by_coordinate(pairs, self.c1_getter)
def group_pairs_by_c2(self, pairs):
return group_by_coordinate(pairs, self.c2_getter)
def other_axis(axis):
return "y" if axis == "x" else "x"
class Stitcher:
def merge_along_axis(self, pairs, axis):
def group_pairs_within_groups_by_c2(groups):
return map(grouper.group_pairs_by_c2, groups)
def merge_groups_along_orthogonal_axis(groups):
return map(group_merger, groups)
grouper = CoordGrouper(axis)
group_merger = make_group_merger(axis)
groups_of_pairs_with_same_c1 = grouper.group_pairs_by_c1(pairs)
groups_of_groups_of_pairs_with_same_c1_and_c2 = group_pairs_within_groups_by_c2(groups_of_pairs_with_same_c1)
groups_of_pairs_with_matching_c1_and_c2 = chain(*groups_of_groups_of_pairs_with_same_c1_and_c2)
groups_of_merged_pairs = merge_groups_along_orthogonal_axis(groups_of_pairs_with_matching_c1_and_c2)
pairs = chain(*groups_of_merged_pairs)
return pairs
def merge_along_both_axes(self, pairs):
pairs = self.merge_along_axis(pairs, "x")
pairs = list(self.merge_along_axis(pairs, "y"))
return pairs
def stitch(self, pairs: Iterable[ImageMetadataPair]) -> List[ImageMetadataPair]:
def break_condition(pairs1, pairs2):
return len(pairs1) == len(pairs2)
return until(self.merge_along_both_axes, break_condition, pairs)

View File

@ -0,0 +1,26 @@
from itertools import groupby
from funcy import compose, second
from image_prediction.stitching.utils import make_coord_getter
def group_by_coordinate(pairs, coord_getter):
pairs = sorted(pairs, key=coord_getter)
return map(compose(list, second), groupby(pairs, coord_getter))
class CoordGrouper:
def __init__(self, axis):
self.c1_getter = make_coord_getter(f"{other_axis(axis)}1")
self.c2_getter = make_coord_getter(f"{other_axis(axis)}2")
def group_pairs_by_lesser_coordinate(self, pairs):
return group_by_coordinate(pairs, self.c1_getter)
def group_pairs_by_greater_coordinate(self, pairs):
return group_by_coordinate(pairs, self.c2_getter)
def other_axis(axis):
return "y" if axis == "x" else "x"

View File

@ -1,32 +1,18 @@
from copy import deepcopy
from functools import reduce
from typing import Iterable
from typing import Iterable, Callable
from PIL import Image
from funcy import juxt, first, rest
from funcy import juxt, first, rest, rcompose
from image_prediction.image_extractor.extractor import ImageMetadataPair
from image_prediction.info import Info
from image_prediction.stitching.grouping import CoordGrouper
from image_prediction.stitching.utils import make_coord_getter, flatten_groups_once
from image_prediction.utils.generic import until
from test.utils.stitching import HorizontalKeyMapper, VerticalKeyMapper
def make_getter(key):
def getter(pair):
return pair.metadata[key]
return getter
def make_coord_getter(c):
return {
"x1": make_getter(Info.X1),
"x2": make_getter(Info.X2),
"y1": make_getter(Info.Y1),
"y2": make_getter(Info.Y2),
}[c]
def make_pair_merger(axis):
return {"y": merge_pair_vertically, "x": merge_pair_horizontally}[axis]
@ -35,24 +21,17 @@ def make_group_merger(axis):
return {"y": merge_group_vertically, "x": merge_group_horizontally}[axis]
def make_length_getter(dim):
return {
"width": make_getter(Info.WIDTH),
"height": make_getter(Info.HEIGHT),
}[dim]
def merge_metadata_horizontally(m1, m2):
def merge_metadata_horizontally(m1: dict, m2: dict):
m1, m2 = map(HorizontalKeyMapper, [m1, m2])
return merge_metadata(m1, m2)
def merge_metadata_vertically(m1, m2):
def merge_metadata_vertically(m1: dict, m2: dict):
m1, m2 = map(VerticalKeyMapper, [m1, m2])
return merge_metadata(m1, m2)
def merge_metadata(m1, m2):
def merge_metadata(m1: dict, m2: dict):
c1 = min(m1.c1, m2.c1)
c2 = max(m1.c2, m2.c2)
@ -101,12 +80,12 @@ def concat_images(im1: Image, im2: Image, metadata: dict, axis):
return im_aggr
def make_merger_aggregator(direction):
def make_merger_aggregator(direction) -> Callable[[Iterable[ImageMetadataPair]], Iterable[ImageMetadataPair]]:
"""Produces a function f : [H, T1, ... Tn] -> [HTi...Tj, Tk ... Tl] that merges adjacent image-metadata pairs on the
head H and aggregates non-adjacent in the tail T.
"""
def merger_aggregator(pairs):
def merger_aggregator(pairs: Iterable[ImageMetadataPair]):
def merge_on_head_and_aggregate_in_tail(pairs_aggr: Iterable[ImageMetadataPair], pair: ImageMetadataPair):
"""Keeps the image that is being merged with as the head and aggregates non-mergables in the tail."""
aggr, non_aggr = juxt(first, rest)(pairs_aggr)
@ -129,17 +108,44 @@ def make_merger_aggregator(direction):
return merger_aggregator
def merge_group(group, direction):
def break_condition(pairs1, pairs2):
return len(pairs1) == len(pairs2)
def merge_group(group: Iterable[ImageMetadataPair], direction):
reduce_group = make_merger_aggregator(direction)
return until(reduce_group, break_condition, group)
return until(no_new_merges, reduce_group, group)
def merge_group_horizontally(group):
def merge_group_horizontally(group: Iterable[ImageMetadataPair]):
return merge_group(group, "x")
def merge_group_vertically(group):
def merge_group_vertically(group: Iterable[ImageMetadataPair]):
return merge_group(group, "y")
def merge_along_axis(pairs: Iterable[ImageMetadataPair], axis):
def group_pairs_within_groups_by_greater_coordinate(groups):
return map(CoordGrouper(axis).group_pairs_by_greater_coordinate, groups)
def merge_groups_along_orthogonal_axis(groups):
return map(make_group_merger(axis), groups)
def group_pairs_by_lesser_coordinate(pairs):
return CoordGrouper(axis).group_pairs_by_lesser_coordinate(pairs)
return rcompose(
group_pairs_by_lesser_coordinate, # pairs -> groups of pairs aligned on one edge
group_pairs_within_groups_by_greater_coordinate, # -> groups of pairs fully aligned on orthogonal axis
flatten_groups_once, # groups of groups of pairs -> groups of pairs
merge_groups_along_orthogonal_axis,
flatten_groups_once, # groups of pairs -> pairs
)(pairs)
def merge_along_both_axes(pairs: Iterable[ImageMetadataPair]):
pairs = merge_along_axis(pairs, "x")
pairs = list(merge_along_axis(pairs, "y"))
return pairs
def no_new_merges(pairs1, pairs2):
return len(pairs1) == len(pairs2)

View File

@ -0,0 +1,11 @@
from typing import Iterable
from image_prediction.image_extractor.extractor import ImageMetadataPair
from image_prediction.stitching.merging import merge_along_both_axes, no_new_merges
from image_prediction.utils.generic import until
def stitch_pairs(pairs: Iterable[ImageMetadataPair]) -> Iterable[ImageMetadataPair]:
"""Given a collection of image-metadata pairs from the same pages, combines all pairs that constitute adjacent
images."""
return until(no_new_merges, merge_along_both_axes, pairs)

View File

@ -0,0 +1,30 @@
from itertools import chain
from image_prediction.info import Info
def make_getter(key):
def getter(pair):
return pair.metadata[key]
return getter
def make_coord_getter(c):
return {
"x1": make_getter(Info.X1),
"x2": make_getter(Info.X2),
"y1": make_getter(Info.Y1),
"y2": make_getter(Info.Y2),
}[c]
def make_length_getter(dim):
return {
"width": make_getter(Info.WIDTH),
"height": make_getter(Info.HEIGHT),
}[dim]
def flatten_groups_once(groups):
return chain.from_iterable(groups)

View File

@ -8,7 +8,7 @@ def chunk_iterable(iterable, chunk_size):
return takewhile(truth, map(tuple, starmap(islice, repeat((iter(iterable), chunk_size)))))
def until(func, cond, *args, **kwargs):
def until(cond, func, *args, **kwargs):
for a, b in chunk_iterable(iterate(func, *args, **kwargs), chunk_size=2):
if cond(a, b):
return a

View File

@ -18,5 +18,5 @@ PyMuPDF==1.19.6
fpdf==1.7.2
coverage==6.3.2
Pillow==9.1.0
PDFNetPython3==9.2.0
PDFNetPython3==9.1.0
pdf2image==1.16.0

View File

@ -13,19 +13,14 @@ from funcy import merge, juxt, one
from image_prediction.estimator.preprocessor.utils import image_to_normalized_tensor
from image_prediction.image_extractor.extractor import ImageMetadataPair
from image_prediction.info import Info
from image_prediction.stitcher.stitcher import Stitcher
from image_prediction.stitcher.utils import (
from image_prediction.stitching.stitching import stitch_pairs
from image_prediction.stitching.utils import (
make_coord_getter,
make_length_getter,
merge_metadata_horizontally,
merge_metadata_vertically,
merge_pair_horizontally,
merge_pair_vertically,
concat_images_horizontally,
concat_images_vertically,
merge_group_horizontally,
merge_group_vertically,
)
from image_prediction.stitching.merging import merge_metadata_horizontally, merge_metadata_vertically, \
merge_pair_horizontally, merge_pair_vertically, concat_images_horizontally, concat_images_vertically, \
merge_group_horizontally, merge_group_vertically
from test.conftest import (
get_base_position_metadata,
add_image,
@ -38,14 +33,9 @@ x1_getter, y1_getter, x2_getter, y2_getter = map(make_coord_getter, ("x1", "y1",
width_getter, height_getter = map(make_length_getter, ("width", "height"))
#####################################
def test_image_stitcher(patch_image_metadata_pairs, base_patch_metadata, base_patch_image):
pair_stitched = Stitcher().stitch(patch_image_metadata_pairs)[0]
pair_stitched = stitch_pairs(patch_image_metadata_pairs)[0]
assert pair_stitched.metadata == base_patch_metadata
# pair_stitched.image.show()
# base_patch_image.show()
assert images_equal(pair_stitched.image.resize((10, 10)), base_patch_image.resize((10, 10)), atol=0.4)