refactoring
This commit is contained in:
parent
1b10445f91
commit
3b18fc6158
@ -1,65 +0,0 @@
|
||||
from itertools import groupby, chain
|
||||
from typing import Iterable, List
|
||||
|
||||
from funcy import compose, second
|
||||
|
||||
from image_prediction.image_extractor.extractor import ImageMetadataPair
|
||||
from image_prediction.stitcher.utils import make_coord_getter, make_group_merger
|
||||
from image_prediction.utils.generic import until
|
||||
|
||||
|
||||
def group_by_coordinate(pairs, coord_getter):
|
||||
pairs = sorted(pairs, key=coord_getter)
|
||||
return map(compose(list, second), groupby(pairs, coord_getter))
|
||||
|
||||
|
||||
class CoordGrouper:
|
||||
def __init__(self, axis):
|
||||
|
||||
self.c1_getter = make_coord_getter(f"{other_axis(axis)}1")
|
||||
self.c2_getter = make_coord_getter(f"{other_axis(axis)}2")
|
||||
|
||||
def get_c1(self, pair: ImageMetadataPair):
|
||||
return self.c1_getter(pair)
|
||||
|
||||
def group_pairs_by_c1(self, pairs):
|
||||
return group_by_coordinate(pairs, self.c1_getter)
|
||||
|
||||
def group_pairs_by_c2(self, pairs):
|
||||
return group_by_coordinate(pairs, self.c2_getter)
|
||||
|
||||
|
||||
def other_axis(axis):
|
||||
return "y" if axis == "x" else "x"
|
||||
|
||||
|
||||
class Stitcher:
|
||||
def merge_along_axis(self, pairs, axis):
|
||||
def group_pairs_within_groups_by_c2(groups):
|
||||
return map(grouper.group_pairs_by_c2, groups)
|
||||
|
||||
def merge_groups_along_orthogonal_axis(groups):
|
||||
return map(group_merger, groups)
|
||||
|
||||
grouper = CoordGrouper(axis)
|
||||
group_merger = make_group_merger(axis)
|
||||
|
||||
groups_of_pairs_with_same_c1 = grouper.group_pairs_by_c1(pairs)
|
||||
groups_of_groups_of_pairs_with_same_c1_and_c2 = group_pairs_within_groups_by_c2(groups_of_pairs_with_same_c1)
|
||||
groups_of_pairs_with_matching_c1_and_c2 = chain(*groups_of_groups_of_pairs_with_same_c1_and_c2)
|
||||
groups_of_merged_pairs = merge_groups_along_orthogonal_axis(groups_of_pairs_with_matching_c1_and_c2)
|
||||
pairs = chain(*groups_of_merged_pairs)
|
||||
|
||||
return pairs
|
||||
|
||||
def merge_along_both_axes(self, pairs):
|
||||
pairs = self.merge_along_axis(pairs, "x")
|
||||
pairs = list(self.merge_along_axis(pairs, "y"))
|
||||
|
||||
return pairs
|
||||
|
||||
def stitch(self, pairs: Iterable[ImageMetadataPair]) -> List[ImageMetadataPair]:
|
||||
def break_condition(pairs1, pairs2):
|
||||
return len(pairs1) == len(pairs2)
|
||||
|
||||
return until(self.merge_along_both_axes, break_condition, pairs)
|
||||
26
image_prediction/stitching/grouping.py
Normal file
26
image_prediction/stitching/grouping.py
Normal file
@ -0,0 +1,26 @@
|
||||
from itertools import groupby
|
||||
|
||||
from funcy import compose, second
|
||||
|
||||
from image_prediction.stitching.utils import make_coord_getter
|
||||
|
||||
|
||||
def group_by_coordinate(pairs, coord_getter):
|
||||
pairs = sorted(pairs, key=coord_getter)
|
||||
return map(compose(list, second), groupby(pairs, coord_getter))
|
||||
|
||||
|
||||
class CoordGrouper:
|
||||
def __init__(self, axis):
|
||||
self.c1_getter = make_coord_getter(f"{other_axis(axis)}1")
|
||||
self.c2_getter = make_coord_getter(f"{other_axis(axis)}2")
|
||||
|
||||
def group_pairs_by_lesser_coordinate(self, pairs):
|
||||
return group_by_coordinate(pairs, self.c1_getter)
|
||||
|
||||
def group_pairs_by_greater_coordinate(self, pairs):
|
||||
return group_by_coordinate(pairs, self.c2_getter)
|
||||
|
||||
|
||||
def other_axis(axis):
|
||||
return "y" if axis == "x" else "x"
|
||||
@ -1,32 +1,18 @@
|
||||
from copy import deepcopy
|
||||
from functools import reduce
|
||||
from typing import Iterable
|
||||
from typing import Iterable, Callable
|
||||
|
||||
from PIL import Image
|
||||
from funcy import juxt, first, rest
|
||||
from funcy import juxt, first, rest, rcompose
|
||||
|
||||
from image_prediction.image_extractor.extractor import ImageMetadataPair
|
||||
from image_prediction.info import Info
|
||||
from image_prediction.stitching.grouping import CoordGrouper
|
||||
from image_prediction.stitching.utils import make_coord_getter, flatten_groups_once
|
||||
from image_prediction.utils.generic import until
|
||||
from test.utils.stitching import HorizontalKeyMapper, VerticalKeyMapper
|
||||
|
||||
|
||||
def make_getter(key):
|
||||
def getter(pair):
|
||||
return pair.metadata[key]
|
||||
|
||||
return getter
|
||||
|
||||
|
||||
def make_coord_getter(c):
|
||||
return {
|
||||
"x1": make_getter(Info.X1),
|
||||
"x2": make_getter(Info.X2),
|
||||
"y1": make_getter(Info.Y1),
|
||||
"y2": make_getter(Info.Y2),
|
||||
}[c]
|
||||
|
||||
|
||||
def make_pair_merger(axis):
|
||||
return {"y": merge_pair_vertically, "x": merge_pair_horizontally}[axis]
|
||||
|
||||
@ -35,24 +21,17 @@ def make_group_merger(axis):
|
||||
return {"y": merge_group_vertically, "x": merge_group_horizontally}[axis]
|
||||
|
||||
|
||||
def make_length_getter(dim):
|
||||
return {
|
||||
"width": make_getter(Info.WIDTH),
|
||||
"height": make_getter(Info.HEIGHT),
|
||||
}[dim]
|
||||
|
||||
|
||||
def merge_metadata_horizontally(m1, m2):
|
||||
def merge_metadata_horizontally(m1: dict, m2: dict):
|
||||
m1, m2 = map(HorizontalKeyMapper, [m1, m2])
|
||||
return merge_metadata(m1, m2)
|
||||
|
||||
|
||||
def merge_metadata_vertically(m1, m2):
|
||||
def merge_metadata_vertically(m1: dict, m2: dict):
|
||||
m1, m2 = map(VerticalKeyMapper, [m1, m2])
|
||||
return merge_metadata(m1, m2)
|
||||
|
||||
|
||||
def merge_metadata(m1, m2):
|
||||
def merge_metadata(m1: dict, m2: dict):
|
||||
|
||||
c1 = min(m1.c1, m2.c1)
|
||||
c2 = max(m1.c2, m2.c2)
|
||||
@ -101,12 +80,12 @@ def concat_images(im1: Image, im2: Image, metadata: dict, axis):
|
||||
return im_aggr
|
||||
|
||||
|
||||
def make_merger_aggregator(direction):
|
||||
def make_merger_aggregator(direction) -> Callable[[Iterable[ImageMetadataPair]], Iterable[ImageMetadataPair]]:
|
||||
"""Produces a function f : [H, T1, ... Tn] -> [HTi...Tj, Tk ... Tl] that merges adjacent image-metadata pairs on the
|
||||
head H and aggregates non-adjacent in the tail T.
|
||||
"""
|
||||
|
||||
def merger_aggregator(pairs):
|
||||
def merger_aggregator(pairs: Iterable[ImageMetadataPair]):
|
||||
def merge_on_head_and_aggregate_in_tail(pairs_aggr: Iterable[ImageMetadataPair], pair: ImageMetadataPair):
|
||||
"""Keeps the image that is being merged with as the head and aggregates non-mergables in the tail."""
|
||||
aggr, non_aggr = juxt(first, rest)(pairs_aggr)
|
||||
@ -129,17 +108,44 @@ def make_merger_aggregator(direction):
|
||||
return merger_aggregator
|
||||
|
||||
|
||||
def merge_group(group, direction):
|
||||
def break_condition(pairs1, pairs2):
|
||||
return len(pairs1) == len(pairs2)
|
||||
|
||||
def merge_group(group: Iterable[ImageMetadataPair], direction):
|
||||
reduce_group = make_merger_aggregator(direction)
|
||||
return until(reduce_group, break_condition, group)
|
||||
return until(no_new_merges, reduce_group, group)
|
||||
|
||||
|
||||
def merge_group_horizontally(group):
|
||||
def merge_group_horizontally(group: Iterable[ImageMetadataPair]):
|
||||
return merge_group(group, "x")
|
||||
|
||||
|
||||
def merge_group_vertically(group):
|
||||
def merge_group_vertically(group: Iterable[ImageMetadataPair]):
|
||||
return merge_group(group, "y")
|
||||
|
||||
|
||||
def merge_along_axis(pairs: Iterable[ImageMetadataPair], axis):
|
||||
def group_pairs_within_groups_by_greater_coordinate(groups):
|
||||
return map(CoordGrouper(axis).group_pairs_by_greater_coordinate, groups)
|
||||
|
||||
def merge_groups_along_orthogonal_axis(groups):
|
||||
return map(make_group_merger(axis), groups)
|
||||
|
||||
def group_pairs_by_lesser_coordinate(pairs):
|
||||
return CoordGrouper(axis).group_pairs_by_lesser_coordinate(pairs)
|
||||
|
||||
return rcompose(
|
||||
group_pairs_by_lesser_coordinate, # pairs -> groups of pairs aligned on one edge
|
||||
group_pairs_within_groups_by_greater_coordinate, # -> groups of pairs fully aligned on orthogonal axis
|
||||
flatten_groups_once, # groups of groups of pairs -> groups of pairs
|
||||
merge_groups_along_orthogonal_axis,
|
||||
flatten_groups_once, # groups of pairs -> pairs
|
||||
)(pairs)
|
||||
|
||||
|
||||
def merge_along_both_axes(pairs: Iterable[ImageMetadataPair]):
|
||||
pairs = merge_along_axis(pairs, "x")
|
||||
pairs = list(merge_along_axis(pairs, "y"))
|
||||
|
||||
return pairs
|
||||
|
||||
|
||||
def no_new_merges(pairs1, pairs2):
|
||||
return len(pairs1) == len(pairs2)
|
||||
11
image_prediction/stitching/stitching.py
Normal file
11
image_prediction/stitching/stitching.py
Normal file
@ -0,0 +1,11 @@
|
||||
from typing import Iterable
|
||||
|
||||
from image_prediction.image_extractor.extractor import ImageMetadataPair
|
||||
from image_prediction.stitching.merging import merge_along_both_axes, no_new_merges
|
||||
from image_prediction.utils.generic import until
|
||||
|
||||
|
||||
def stitch_pairs(pairs: Iterable[ImageMetadataPair]) -> Iterable[ImageMetadataPair]:
|
||||
"""Given a collection of image-metadata pairs from the same pages, combines all pairs that constitute adjacent
|
||||
images."""
|
||||
return until(no_new_merges, merge_along_both_axes, pairs)
|
||||
30
image_prediction/stitching/utils.py
Normal file
30
image_prediction/stitching/utils.py
Normal file
@ -0,0 +1,30 @@
|
||||
from itertools import chain
|
||||
|
||||
from image_prediction.info import Info
|
||||
|
||||
|
||||
def make_getter(key):
|
||||
def getter(pair):
|
||||
return pair.metadata[key]
|
||||
|
||||
return getter
|
||||
|
||||
|
||||
def make_coord_getter(c):
|
||||
return {
|
||||
"x1": make_getter(Info.X1),
|
||||
"x2": make_getter(Info.X2),
|
||||
"y1": make_getter(Info.Y1),
|
||||
"y2": make_getter(Info.Y2),
|
||||
}[c]
|
||||
|
||||
|
||||
def make_length_getter(dim):
|
||||
return {
|
||||
"width": make_getter(Info.WIDTH),
|
||||
"height": make_getter(Info.HEIGHT),
|
||||
}[dim]
|
||||
|
||||
|
||||
def flatten_groups_once(groups):
|
||||
return chain.from_iterable(groups)
|
||||
@ -8,7 +8,7 @@ def chunk_iterable(iterable, chunk_size):
|
||||
return takewhile(truth, map(tuple, starmap(islice, repeat((iter(iterable), chunk_size)))))
|
||||
|
||||
|
||||
def until(func, cond, *args, **kwargs):
|
||||
def until(cond, func, *args, **kwargs):
|
||||
for a, b in chunk_iterable(iterate(func, *args, **kwargs), chunk_size=2):
|
||||
if cond(a, b):
|
||||
return a
|
||||
|
||||
@ -18,5 +18,5 @@ PyMuPDF==1.19.6
|
||||
fpdf==1.7.2
|
||||
coverage==6.3.2
|
||||
Pillow==9.1.0
|
||||
PDFNetPython3==9.2.0
|
||||
PDFNetPython3==9.1.0
|
||||
pdf2image==1.16.0
|
||||
@ -13,19 +13,14 @@ from funcy import merge, juxt, one
|
||||
from image_prediction.estimator.preprocessor.utils import image_to_normalized_tensor
|
||||
from image_prediction.image_extractor.extractor import ImageMetadataPair
|
||||
from image_prediction.info import Info
|
||||
from image_prediction.stitcher.stitcher import Stitcher
|
||||
from image_prediction.stitcher.utils import (
|
||||
from image_prediction.stitching.stitching import stitch_pairs
|
||||
from image_prediction.stitching.utils import (
|
||||
make_coord_getter,
|
||||
make_length_getter,
|
||||
merge_metadata_horizontally,
|
||||
merge_metadata_vertically,
|
||||
merge_pair_horizontally,
|
||||
merge_pair_vertically,
|
||||
concat_images_horizontally,
|
||||
concat_images_vertically,
|
||||
merge_group_horizontally,
|
||||
merge_group_vertically,
|
||||
)
|
||||
from image_prediction.stitching.merging import merge_metadata_horizontally, merge_metadata_vertically, \
|
||||
merge_pair_horizontally, merge_pair_vertically, concat_images_horizontally, concat_images_vertically, \
|
||||
merge_group_horizontally, merge_group_vertically
|
||||
from test.conftest import (
|
||||
get_base_position_metadata,
|
||||
add_image,
|
||||
@ -38,14 +33,9 @@ x1_getter, y1_getter, x2_getter, y2_getter = map(make_coord_getter, ("x1", "y1",
|
||||
width_getter, height_getter = map(make_length_getter, ("width", "height"))
|
||||
|
||||
|
||||
#####################################
|
||||
|
||||
|
||||
def test_image_stitcher(patch_image_metadata_pairs, base_patch_metadata, base_patch_image):
|
||||
pair_stitched = Stitcher().stitch(patch_image_metadata_pairs)[0]
|
||||
pair_stitched = stitch_pairs(patch_image_metadata_pairs)[0]
|
||||
assert pair_stitched.metadata == base_patch_metadata
|
||||
# pair_stitched.image.show()
|
||||
# base_patch_image.show()
|
||||
assert images_equal(pair_stitched.image.resize((10, 10)), base_patch_image.resize((10, 10)), atol=0.4)
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user