refactoring
This commit is contained in:
parent
1b10445f91
commit
3b18fc6158
@ -1,65 +0,0 @@
|
|||||||
from itertools import groupby, chain
|
|
||||||
from typing import Iterable, List
|
|
||||||
|
|
||||||
from funcy import compose, second
|
|
||||||
|
|
||||||
from image_prediction.image_extractor.extractor import ImageMetadataPair
|
|
||||||
from image_prediction.stitcher.utils import make_coord_getter, make_group_merger
|
|
||||||
from image_prediction.utils.generic import until
|
|
||||||
|
|
||||||
|
|
||||||
def group_by_coordinate(pairs, coord_getter):
|
|
||||||
pairs = sorted(pairs, key=coord_getter)
|
|
||||||
return map(compose(list, second), groupby(pairs, coord_getter))
|
|
||||||
|
|
||||||
|
|
||||||
class CoordGrouper:
|
|
||||||
def __init__(self, axis):
|
|
||||||
|
|
||||||
self.c1_getter = make_coord_getter(f"{other_axis(axis)}1")
|
|
||||||
self.c2_getter = make_coord_getter(f"{other_axis(axis)}2")
|
|
||||||
|
|
||||||
def get_c1(self, pair: ImageMetadataPair):
|
|
||||||
return self.c1_getter(pair)
|
|
||||||
|
|
||||||
def group_pairs_by_c1(self, pairs):
|
|
||||||
return group_by_coordinate(pairs, self.c1_getter)
|
|
||||||
|
|
||||||
def group_pairs_by_c2(self, pairs):
|
|
||||||
return group_by_coordinate(pairs, self.c2_getter)
|
|
||||||
|
|
||||||
|
|
||||||
def other_axis(axis):
|
|
||||||
return "y" if axis == "x" else "x"
|
|
||||||
|
|
||||||
|
|
||||||
class Stitcher:
|
|
||||||
def merge_along_axis(self, pairs, axis):
|
|
||||||
def group_pairs_within_groups_by_c2(groups):
|
|
||||||
return map(grouper.group_pairs_by_c2, groups)
|
|
||||||
|
|
||||||
def merge_groups_along_orthogonal_axis(groups):
|
|
||||||
return map(group_merger, groups)
|
|
||||||
|
|
||||||
grouper = CoordGrouper(axis)
|
|
||||||
group_merger = make_group_merger(axis)
|
|
||||||
|
|
||||||
groups_of_pairs_with_same_c1 = grouper.group_pairs_by_c1(pairs)
|
|
||||||
groups_of_groups_of_pairs_with_same_c1_and_c2 = group_pairs_within_groups_by_c2(groups_of_pairs_with_same_c1)
|
|
||||||
groups_of_pairs_with_matching_c1_and_c2 = chain(*groups_of_groups_of_pairs_with_same_c1_and_c2)
|
|
||||||
groups_of_merged_pairs = merge_groups_along_orthogonal_axis(groups_of_pairs_with_matching_c1_and_c2)
|
|
||||||
pairs = chain(*groups_of_merged_pairs)
|
|
||||||
|
|
||||||
return pairs
|
|
||||||
|
|
||||||
def merge_along_both_axes(self, pairs):
|
|
||||||
pairs = self.merge_along_axis(pairs, "x")
|
|
||||||
pairs = list(self.merge_along_axis(pairs, "y"))
|
|
||||||
|
|
||||||
return pairs
|
|
||||||
|
|
||||||
def stitch(self, pairs: Iterable[ImageMetadataPair]) -> List[ImageMetadataPair]:
|
|
||||||
def break_condition(pairs1, pairs2):
|
|
||||||
return len(pairs1) == len(pairs2)
|
|
||||||
|
|
||||||
return until(self.merge_along_both_axes, break_condition, pairs)
|
|
||||||
26
image_prediction/stitching/grouping.py
Normal file
26
image_prediction/stitching/grouping.py
Normal file
@ -0,0 +1,26 @@
|
|||||||
|
from itertools import groupby
|
||||||
|
|
||||||
|
from funcy import compose, second
|
||||||
|
|
||||||
|
from image_prediction.stitching.utils import make_coord_getter
|
||||||
|
|
||||||
|
|
||||||
|
def group_by_coordinate(pairs, coord_getter):
|
||||||
|
pairs = sorted(pairs, key=coord_getter)
|
||||||
|
return map(compose(list, second), groupby(pairs, coord_getter))
|
||||||
|
|
||||||
|
|
||||||
|
class CoordGrouper:
|
||||||
|
def __init__(self, axis):
|
||||||
|
self.c1_getter = make_coord_getter(f"{other_axis(axis)}1")
|
||||||
|
self.c2_getter = make_coord_getter(f"{other_axis(axis)}2")
|
||||||
|
|
||||||
|
def group_pairs_by_lesser_coordinate(self, pairs):
|
||||||
|
return group_by_coordinate(pairs, self.c1_getter)
|
||||||
|
|
||||||
|
def group_pairs_by_greater_coordinate(self, pairs):
|
||||||
|
return group_by_coordinate(pairs, self.c2_getter)
|
||||||
|
|
||||||
|
|
||||||
|
def other_axis(axis):
|
||||||
|
return "y" if axis == "x" else "x"
|
||||||
@ -1,32 +1,18 @@
|
|||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
from functools import reduce
|
from functools import reduce
|
||||||
from typing import Iterable
|
from typing import Iterable, Callable
|
||||||
|
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
from funcy import juxt, first, rest
|
from funcy import juxt, first, rest, rcompose
|
||||||
|
|
||||||
from image_prediction.image_extractor.extractor import ImageMetadataPair
|
from image_prediction.image_extractor.extractor import ImageMetadataPair
|
||||||
from image_prediction.info import Info
|
from image_prediction.info import Info
|
||||||
|
from image_prediction.stitching.grouping import CoordGrouper
|
||||||
|
from image_prediction.stitching.utils import make_coord_getter, flatten_groups_once
|
||||||
from image_prediction.utils.generic import until
|
from image_prediction.utils.generic import until
|
||||||
from test.utils.stitching import HorizontalKeyMapper, VerticalKeyMapper
|
from test.utils.stitching import HorizontalKeyMapper, VerticalKeyMapper
|
||||||
|
|
||||||
|
|
||||||
def make_getter(key):
|
|
||||||
def getter(pair):
|
|
||||||
return pair.metadata[key]
|
|
||||||
|
|
||||||
return getter
|
|
||||||
|
|
||||||
|
|
||||||
def make_coord_getter(c):
|
|
||||||
return {
|
|
||||||
"x1": make_getter(Info.X1),
|
|
||||||
"x2": make_getter(Info.X2),
|
|
||||||
"y1": make_getter(Info.Y1),
|
|
||||||
"y2": make_getter(Info.Y2),
|
|
||||||
}[c]
|
|
||||||
|
|
||||||
|
|
||||||
def make_pair_merger(axis):
|
def make_pair_merger(axis):
|
||||||
return {"y": merge_pair_vertically, "x": merge_pair_horizontally}[axis]
|
return {"y": merge_pair_vertically, "x": merge_pair_horizontally}[axis]
|
||||||
|
|
||||||
@ -35,24 +21,17 @@ def make_group_merger(axis):
|
|||||||
return {"y": merge_group_vertically, "x": merge_group_horizontally}[axis]
|
return {"y": merge_group_vertically, "x": merge_group_horizontally}[axis]
|
||||||
|
|
||||||
|
|
||||||
def make_length_getter(dim):
|
def merge_metadata_horizontally(m1: dict, m2: dict):
|
||||||
return {
|
|
||||||
"width": make_getter(Info.WIDTH),
|
|
||||||
"height": make_getter(Info.HEIGHT),
|
|
||||||
}[dim]
|
|
||||||
|
|
||||||
|
|
||||||
def merge_metadata_horizontally(m1, m2):
|
|
||||||
m1, m2 = map(HorizontalKeyMapper, [m1, m2])
|
m1, m2 = map(HorizontalKeyMapper, [m1, m2])
|
||||||
return merge_metadata(m1, m2)
|
return merge_metadata(m1, m2)
|
||||||
|
|
||||||
|
|
||||||
def merge_metadata_vertically(m1, m2):
|
def merge_metadata_vertically(m1: dict, m2: dict):
|
||||||
m1, m2 = map(VerticalKeyMapper, [m1, m2])
|
m1, m2 = map(VerticalKeyMapper, [m1, m2])
|
||||||
return merge_metadata(m1, m2)
|
return merge_metadata(m1, m2)
|
||||||
|
|
||||||
|
|
||||||
def merge_metadata(m1, m2):
|
def merge_metadata(m1: dict, m2: dict):
|
||||||
|
|
||||||
c1 = min(m1.c1, m2.c1)
|
c1 = min(m1.c1, m2.c1)
|
||||||
c2 = max(m1.c2, m2.c2)
|
c2 = max(m1.c2, m2.c2)
|
||||||
@ -101,12 +80,12 @@ def concat_images(im1: Image, im2: Image, metadata: dict, axis):
|
|||||||
return im_aggr
|
return im_aggr
|
||||||
|
|
||||||
|
|
||||||
def make_merger_aggregator(direction):
|
def make_merger_aggregator(direction) -> Callable[[Iterable[ImageMetadataPair]], Iterable[ImageMetadataPair]]:
|
||||||
"""Produces a function f : [H, T1, ... Tn] -> [HTi...Tj, Tk ... Tl] that merges adjacent image-metadata pairs on the
|
"""Produces a function f : [H, T1, ... Tn] -> [HTi...Tj, Tk ... Tl] that merges adjacent image-metadata pairs on the
|
||||||
head H and aggregates non-adjacent in the tail T.
|
head H and aggregates non-adjacent in the tail T.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def merger_aggregator(pairs):
|
def merger_aggregator(pairs: Iterable[ImageMetadataPair]):
|
||||||
def merge_on_head_and_aggregate_in_tail(pairs_aggr: Iterable[ImageMetadataPair], pair: ImageMetadataPair):
|
def merge_on_head_and_aggregate_in_tail(pairs_aggr: Iterable[ImageMetadataPair], pair: ImageMetadataPair):
|
||||||
"""Keeps the image that is being merged with as the head and aggregates non-mergables in the tail."""
|
"""Keeps the image that is being merged with as the head and aggregates non-mergables in the tail."""
|
||||||
aggr, non_aggr = juxt(first, rest)(pairs_aggr)
|
aggr, non_aggr = juxt(first, rest)(pairs_aggr)
|
||||||
@ -129,17 +108,44 @@ def make_merger_aggregator(direction):
|
|||||||
return merger_aggregator
|
return merger_aggregator
|
||||||
|
|
||||||
|
|
||||||
def merge_group(group, direction):
|
def merge_group(group: Iterable[ImageMetadataPair], direction):
|
||||||
def break_condition(pairs1, pairs2):
|
|
||||||
return len(pairs1) == len(pairs2)
|
|
||||||
|
|
||||||
reduce_group = make_merger_aggregator(direction)
|
reduce_group = make_merger_aggregator(direction)
|
||||||
return until(reduce_group, break_condition, group)
|
return until(no_new_merges, reduce_group, group)
|
||||||
|
|
||||||
|
|
||||||
def merge_group_horizontally(group):
|
def merge_group_horizontally(group: Iterable[ImageMetadataPair]):
|
||||||
return merge_group(group, "x")
|
return merge_group(group, "x")
|
||||||
|
|
||||||
|
|
||||||
def merge_group_vertically(group):
|
def merge_group_vertically(group: Iterable[ImageMetadataPair]):
|
||||||
return merge_group(group, "y")
|
return merge_group(group, "y")
|
||||||
|
|
||||||
|
|
||||||
|
def merge_along_axis(pairs: Iterable[ImageMetadataPair], axis):
|
||||||
|
def group_pairs_within_groups_by_greater_coordinate(groups):
|
||||||
|
return map(CoordGrouper(axis).group_pairs_by_greater_coordinate, groups)
|
||||||
|
|
||||||
|
def merge_groups_along_orthogonal_axis(groups):
|
||||||
|
return map(make_group_merger(axis), groups)
|
||||||
|
|
||||||
|
def group_pairs_by_lesser_coordinate(pairs):
|
||||||
|
return CoordGrouper(axis).group_pairs_by_lesser_coordinate(pairs)
|
||||||
|
|
||||||
|
return rcompose(
|
||||||
|
group_pairs_by_lesser_coordinate, # pairs -> groups of pairs aligned on one edge
|
||||||
|
group_pairs_within_groups_by_greater_coordinate, # -> groups of pairs fully aligned on orthogonal axis
|
||||||
|
flatten_groups_once, # groups of groups of pairs -> groups of pairs
|
||||||
|
merge_groups_along_orthogonal_axis,
|
||||||
|
flatten_groups_once, # groups of pairs -> pairs
|
||||||
|
)(pairs)
|
||||||
|
|
||||||
|
|
||||||
|
def merge_along_both_axes(pairs: Iterable[ImageMetadataPair]):
|
||||||
|
pairs = merge_along_axis(pairs, "x")
|
||||||
|
pairs = list(merge_along_axis(pairs, "y"))
|
||||||
|
|
||||||
|
return pairs
|
||||||
|
|
||||||
|
|
||||||
|
def no_new_merges(pairs1, pairs2):
|
||||||
|
return len(pairs1) == len(pairs2)
|
||||||
11
image_prediction/stitching/stitching.py
Normal file
11
image_prediction/stitching/stitching.py
Normal file
@ -0,0 +1,11 @@
|
|||||||
|
from typing import Iterable
|
||||||
|
|
||||||
|
from image_prediction.image_extractor.extractor import ImageMetadataPair
|
||||||
|
from image_prediction.stitching.merging import merge_along_both_axes, no_new_merges
|
||||||
|
from image_prediction.utils.generic import until
|
||||||
|
|
||||||
|
|
||||||
|
def stitch_pairs(pairs: Iterable[ImageMetadataPair]) -> Iterable[ImageMetadataPair]:
|
||||||
|
"""Given a collection of image-metadata pairs from the same pages, combines all pairs that constitute adjacent
|
||||||
|
images."""
|
||||||
|
return until(no_new_merges, merge_along_both_axes, pairs)
|
||||||
30
image_prediction/stitching/utils.py
Normal file
30
image_prediction/stitching/utils.py
Normal file
@ -0,0 +1,30 @@
|
|||||||
|
from itertools import chain
|
||||||
|
|
||||||
|
from image_prediction.info import Info
|
||||||
|
|
||||||
|
|
||||||
|
def make_getter(key):
|
||||||
|
def getter(pair):
|
||||||
|
return pair.metadata[key]
|
||||||
|
|
||||||
|
return getter
|
||||||
|
|
||||||
|
|
||||||
|
def make_coord_getter(c):
|
||||||
|
return {
|
||||||
|
"x1": make_getter(Info.X1),
|
||||||
|
"x2": make_getter(Info.X2),
|
||||||
|
"y1": make_getter(Info.Y1),
|
||||||
|
"y2": make_getter(Info.Y2),
|
||||||
|
}[c]
|
||||||
|
|
||||||
|
|
||||||
|
def make_length_getter(dim):
|
||||||
|
return {
|
||||||
|
"width": make_getter(Info.WIDTH),
|
||||||
|
"height": make_getter(Info.HEIGHT),
|
||||||
|
}[dim]
|
||||||
|
|
||||||
|
|
||||||
|
def flatten_groups_once(groups):
|
||||||
|
return chain.from_iterable(groups)
|
||||||
@ -8,7 +8,7 @@ def chunk_iterable(iterable, chunk_size):
|
|||||||
return takewhile(truth, map(tuple, starmap(islice, repeat((iter(iterable), chunk_size)))))
|
return takewhile(truth, map(tuple, starmap(islice, repeat((iter(iterable), chunk_size)))))
|
||||||
|
|
||||||
|
|
||||||
def until(func, cond, *args, **kwargs):
|
def until(cond, func, *args, **kwargs):
|
||||||
for a, b in chunk_iterable(iterate(func, *args, **kwargs), chunk_size=2):
|
for a, b in chunk_iterable(iterate(func, *args, **kwargs), chunk_size=2):
|
||||||
if cond(a, b):
|
if cond(a, b):
|
||||||
return a
|
return a
|
||||||
|
|||||||
@ -18,5 +18,5 @@ PyMuPDF==1.19.6
|
|||||||
fpdf==1.7.2
|
fpdf==1.7.2
|
||||||
coverage==6.3.2
|
coverage==6.3.2
|
||||||
Pillow==9.1.0
|
Pillow==9.1.0
|
||||||
PDFNetPython3==9.2.0
|
PDFNetPython3==9.1.0
|
||||||
pdf2image==1.16.0
|
pdf2image==1.16.0
|
||||||
@ -13,19 +13,14 @@ from funcy import merge, juxt, one
|
|||||||
from image_prediction.estimator.preprocessor.utils import image_to_normalized_tensor
|
from image_prediction.estimator.preprocessor.utils import image_to_normalized_tensor
|
||||||
from image_prediction.image_extractor.extractor import ImageMetadataPair
|
from image_prediction.image_extractor.extractor import ImageMetadataPair
|
||||||
from image_prediction.info import Info
|
from image_prediction.info import Info
|
||||||
from image_prediction.stitcher.stitcher import Stitcher
|
from image_prediction.stitching.stitching import stitch_pairs
|
||||||
from image_prediction.stitcher.utils import (
|
from image_prediction.stitching.utils import (
|
||||||
make_coord_getter,
|
make_coord_getter,
|
||||||
make_length_getter,
|
make_length_getter,
|
||||||
merge_metadata_horizontally,
|
|
||||||
merge_metadata_vertically,
|
|
||||||
merge_pair_horizontally,
|
|
||||||
merge_pair_vertically,
|
|
||||||
concat_images_horizontally,
|
|
||||||
concat_images_vertically,
|
|
||||||
merge_group_horizontally,
|
|
||||||
merge_group_vertically,
|
|
||||||
)
|
)
|
||||||
|
from image_prediction.stitching.merging import merge_metadata_horizontally, merge_metadata_vertically, \
|
||||||
|
merge_pair_horizontally, merge_pair_vertically, concat_images_horizontally, concat_images_vertically, \
|
||||||
|
merge_group_horizontally, merge_group_vertically
|
||||||
from test.conftest import (
|
from test.conftest import (
|
||||||
get_base_position_metadata,
|
get_base_position_metadata,
|
||||||
add_image,
|
add_image,
|
||||||
@ -38,14 +33,9 @@ x1_getter, y1_getter, x2_getter, y2_getter = map(make_coord_getter, ("x1", "y1",
|
|||||||
width_getter, height_getter = map(make_length_getter, ("width", "height"))
|
width_getter, height_getter = map(make_length_getter, ("width", "height"))
|
||||||
|
|
||||||
|
|
||||||
#####################################
|
|
||||||
|
|
||||||
|
|
||||||
def test_image_stitcher(patch_image_metadata_pairs, base_patch_metadata, base_patch_image):
|
def test_image_stitcher(patch_image_metadata_pairs, base_patch_metadata, base_patch_image):
|
||||||
pair_stitched = Stitcher().stitch(patch_image_metadata_pairs)[0]
|
pair_stitched = stitch_pairs(patch_image_metadata_pairs)[0]
|
||||||
assert pair_stitched.metadata == base_patch_metadata
|
assert pair_stitched.metadata == base_patch_metadata
|
||||||
# pair_stitched.image.show()
|
|
||||||
# base_patch_image.show()
|
|
||||||
assert images_equal(pair_stitched.image.resize((10, 10)), base_patch_image.resize((10, 10)), atol=0.4)
|
assert images_equal(pair_stitched.image.resize((10, 10)), base_patch_image.resize((10, 10)), atol=0.4)
|
||||||
|
|
||||||
|
|
||||||
Loading…
x
Reference in New Issue
Block a user