diff --git a/image_prediction/stitcher/stitcher.py b/image_prediction/stitcher/stitcher.py deleted file mode 100644 index 441ea2e..0000000 --- a/image_prediction/stitcher/stitcher.py +++ /dev/null @@ -1,65 +0,0 @@ -from itertools import groupby, chain -from typing import Iterable, List - -from funcy import compose, second - -from image_prediction.image_extractor.extractor import ImageMetadataPair -from image_prediction.stitcher.utils import make_coord_getter, make_group_merger -from image_prediction.utils.generic import until - - -def group_by_coordinate(pairs, coord_getter): - pairs = sorted(pairs, key=coord_getter) - return map(compose(list, second), groupby(pairs, coord_getter)) - - -class CoordGrouper: - def __init__(self, axis): - - self.c1_getter = make_coord_getter(f"{other_axis(axis)}1") - self.c2_getter = make_coord_getter(f"{other_axis(axis)}2") - - def get_c1(self, pair: ImageMetadataPair): - return self.c1_getter(pair) - - def group_pairs_by_c1(self, pairs): - return group_by_coordinate(pairs, self.c1_getter) - - def group_pairs_by_c2(self, pairs): - return group_by_coordinate(pairs, self.c2_getter) - - -def other_axis(axis): - return "y" if axis == "x" else "x" - - -class Stitcher: - def merge_along_axis(self, pairs, axis): - def group_pairs_within_groups_by_c2(groups): - return map(grouper.group_pairs_by_c2, groups) - - def merge_groups_along_orthogonal_axis(groups): - return map(group_merger, groups) - - grouper = CoordGrouper(axis) - group_merger = make_group_merger(axis) - - groups_of_pairs_with_same_c1 = grouper.group_pairs_by_c1(pairs) - groups_of_groups_of_pairs_with_same_c1_and_c2 = group_pairs_within_groups_by_c2(groups_of_pairs_with_same_c1) - groups_of_pairs_with_matching_c1_and_c2 = chain(*groups_of_groups_of_pairs_with_same_c1_and_c2) - groups_of_merged_pairs = merge_groups_along_orthogonal_axis(groups_of_pairs_with_matching_c1_and_c2) - pairs = chain(*groups_of_merged_pairs) - - return pairs - - def merge_along_both_axes(self, pairs): - pairs = self.merge_along_axis(pairs, "x") - pairs = list(self.merge_along_axis(pairs, "y")) - - return pairs - - def stitch(self, pairs: Iterable[ImageMetadataPair]) -> List[ImageMetadataPair]: - def break_condition(pairs1, pairs2): - return len(pairs1) == len(pairs2) - - return until(self.merge_along_both_axes, break_condition, pairs) diff --git a/image_prediction/stitcher/__init__.py b/image_prediction/stitching/__init__.py similarity index 100% rename from image_prediction/stitcher/__init__.py rename to image_prediction/stitching/__init__.py diff --git a/image_prediction/stitching/grouping.py b/image_prediction/stitching/grouping.py new file mode 100644 index 0000000..8c56325 --- /dev/null +++ b/image_prediction/stitching/grouping.py @@ -0,0 +1,26 @@ +from itertools import groupby + +from funcy import compose, second + +from image_prediction.stitching.utils import make_coord_getter + + +def group_by_coordinate(pairs, coord_getter): + pairs = sorted(pairs, key=coord_getter) + return map(compose(list, second), groupby(pairs, coord_getter)) + + +class CoordGrouper: + def __init__(self, axis): + self.c1_getter = make_coord_getter(f"{other_axis(axis)}1") + self.c2_getter = make_coord_getter(f"{other_axis(axis)}2") + + def group_pairs_by_lesser_coordinate(self, pairs): + return group_by_coordinate(pairs, self.c1_getter) + + def group_pairs_by_greater_coordinate(self, pairs): + return group_by_coordinate(pairs, self.c2_getter) + + +def other_axis(axis): + return "y" if axis == "x" else "x" diff --git a/image_prediction/stitcher/utils.py b/image_prediction/stitching/merging.py similarity index 64% rename from image_prediction/stitcher/utils.py rename to image_prediction/stitching/merging.py index 81a6405..7cdd61e 100644 --- a/image_prediction/stitcher/utils.py +++ b/image_prediction/stitching/merging.py @@ -1,32 +1,18 @@ from copy import deepcopy from functools import reduce -from typing import Iterable +from typing import Iterable, Callable from PIL import Image -from funcy import juxt, first, rest +from funcy import juxt, first, rest, rcompose from image_prediction.image_extractor.extractor import ImageMetadataPair from image_prediction.info import Info +from image_prediction.stitching.grouping import CoordGrouper +from image_prediction.stitching.utils import make_coord_getter, flatten_groups_once from image_prediction.utils.generic import until from test.utils.stitching import HorizontalKeyMapper, VerticalKeyMapper -def make_getter(key): - def getter(pair): - return pair.metadata[key] - - return getter - - -def make_coord_getter(c): - return { - "x1": make_getter(Info.X1), - "x2": make_getter(Info.X2), - "y1": make_getter(Info.Y1), - "y2": make_getter(Info.Y2), - }[c] - - def make_pair_merger(axis): return {"y": merge_pair_vertically, "x": merge_pair_horizontally}[axis] @@ -35,24 +21,17 @@ def make_group_merger(axis): return {"y": merge_group_vertically, "x": merge_group_horizontally}[axis] -def make_length_getter(dim): - return { - "width": make_getter(Info.WIDTH), - "height": make_getter(Info.HEIGHT), - }[dim] - - -def merge_metadata_horizontally(m1, m2): +def merge_metadata_horizontally(m1: dict, m2: dict): m1, m2 = map(HorizontalKeyMapper, [m1, m2]) return merge_metadata(m1, m2) -def merge_metadata_vertically(m1, m2): +def merge_metadata_vertically(m1: dict, m2: dict): m1, m2 = map(VerticalKeyMapper, [m1, m2]) return merge_metadata(m1, m2) -def merge_metadata(m1, m2): +def merge_metadata(m1: dict, m2: dict): c1 = min(m1.c1, m2.c1) c2 = max(m1.c2, m2.c2) @@ -101,12 +80,12 @@ def concat_images(im1: Image, im2: Image, metadata: dict, axis): return im_aggr -def make_merger_aggregator(direction): +def make_merger_aggregator(direction) -> Callable[[Iterable[ImageMetadataPair]], Iterable[ImageMetadataPair]]: """Produces a function f : [H, T1, ... Tn] -> [HTi...Tj, Tk ... Tl] that merges adjacent image-metadata pairs on the head H and aggregates non-adjacent in the tail T. """ - def merger_aggregator(pairs): + def merger_aggregator(pairs: Iterable[ImageMetadataPair]): def merge_on_head_and_aggregate_in_tail(pairs_aggr: Iterable[ImageMetadataPair], pair: ImageMetadataPair): """Keeps the image that is being merged with as the head and aggregates non-mergables in the tail.""" aggr, non_aggr = juxt(first, rest)(pairs_aggr) @@ -129,17 +108,44 @@ def make_merger_aggregator(direction): return merger_aggregator -def merge_group(group, direction): - def break_condition(pairs1, pairs2): - return len(pairs1) == len(pairs2) - +def merge_group(group: Iterable[ImageMetadataPair], direction): reduce_group = make_merger_aggregator(direction) - return until(reduce_group, break_condition, group) + return until(no_new_merges, reduce_group, group) -def merge_group_horizontally(group): +def merge_group_horizontally(group: Iterable[ImageMetadataPair]): return merge_group(group, "x") -def merge_group_vertically(group): +def merge_group_vertically(group: Iterable[ImageMetadataPair]): return merge_group(group, "y") + + +def merge_along_axis(pairs: Iterable[ImageMetadataPair], axis): + def group_pairs_within_groups_by_greater_coordinate(groups): + return map(CoordGrouper(axis).group_pairs_by_greater_coordinate, groups) + + def merge_groups_along_orthogonal_axis(groups): + return map(make_group_merger(axis), groups) + + def group_pairs_by_lesser_coordinate(pairs): + return CoordGrouper(axis).group_pairs_by_lesser_coordinate(pairs) + + return rcompose( + group_pairs_by_lesser_coordinate, # pairs -> groups of pairs aligned on one edge + group_pairs_within_groups_by_greater_coordinate, # -> groups of pairs fully aligned on orthogonal axis + flatten_groups_once, # groups of groups of pairs -> groups of pairs + merge_groups_along_orthogonal_axis, + flatten_groups_once, # groups of pairs -> pairs + )(pairs) + + +def merge_along_both_axes(pairs: Iterable[ImageMetadataPair]): + pairs = merge_along_axis(pairs, "x") + pairs = list(merge_along_axis(pairs, "y")) + + return pairs + + +def no_new_merges(pairs1, pairs2): + return len(pairs1) == len(pairs2) diff --git a/image_prediction/stitching/stitching.py b/image_prediction/stitching/stitching.py new file mode 100644 index 0000000..bb8c519 --- /dev/null +++ b/image_prediction/stitching/stitching.py @@ -0,0 +1,11 @@ +from typing import Iterable + +from image_prediction.image_extractor.extractor import ImageMetadataPair +from image_prediction.stitching.merging import merge_along_both_axes, no_new_merges +from image_prediction.utils.generic import until + + +def stitch_pairs(pairs: Iterable[ImageMetadataPair]) -> Iterable[ImageMetadataPair]: + """Given a collection of image-metadata pairs from the same pages, combines all pairs that constitute adjacent + images.""" + return until(no_new_merges, merge_along_both_axes, pairs) diff --git a/image_prediction/stitching/utils.py b/image_prediction/stitching/utils.py new file mode 100644 index 0000000..422c160 --- /dev/null +++ b/image_prediction/stitching/utils.py @@ -0,0 +1,30 @@ +from itertools import chain + +from image_prediction.info import Info + + +def make_getter(key): + def getter(pair): + return pair.metadata[key] + + return getter + + +def make_coord_getter(c): + return { + "x1": make_getter(Info.X1), + "x2": make_getter(Info.X2), + "y1": make_getter(Info.Y1), + "y2": make_getter(Info.Y2), + }[c] + + +def make_length_getter(dim): + return { + "width": make_getter(Info.WIDTH), + "height": make_getter(Info.HEIGHT), + }[dim] + + +def flatten_groups_once(groups): + return chain.from_iterable(groups) \ No newline at end of file diff --git a/image_prediction/utils/generic.py b/image_prediction/utils/generic.py index c02e016..4b7a2b0 100644 --- a/image_prediction/utils/generic.py +++ b/image_prediction/utils/generic.py @@ -8,7 +8,7 @@ def chunk_iterable(iterable, chunk_size): return takewhile(truth, map(tuple, starmap(islice, repeat((iter(iterable), chunk_size))))) -def until(func, cond, *args, **kwargs): +def until(cond, func, *args, **kwargs): for a, b in chunk_iterable(iterate(func, *args, **kwargs), chunk_size=2): if cond(a, b): return a diff --git a/requirements.txt b/requirements.txt index 4ac4acb..bbfa69a 100644 --- a/requirements.txt +++ b/requirements.txt @@ -18,5 +18,5 @@ PyMuPDF==1.19.6 fpdf==1.7.2 coverage==6.3.2 Pillow==9.1.0 -PDFNetPython3==9.2.0 +PDFNetPython3==9.1.0 pdf2image==1.16.0 \ No newline at end of file diff --git a/test/unit_tests/image_stitcher_test.py b/test/unit_tests/image_stitching_test.py similarity index 92% rename from test/unit_tests/image_stitcher_test.py rename to test/unit_tests/image_stitching_test.py index 9dabe6a..e469e47 100644 --- a/test/unit_tests/image_stitcher_test.py +++ b/test/unit_tests/image_stitching_test.py @@ -13,19 +13,14 @@ from funcy import merge, juxt, one from image_prediction.estimator.preprocessor.utils import image_to_normalized_tensor from image_prediction.image_extractor.extractor import ImageMetadataPair from image_prediction.info import Info -from image_prediction.stitcher.stitcher import Stitcher -from image_prediction.stitcher.utils import ( +from image_prediction.stitching.stitching import stitch_pairs +from image_prediction.stitching.utils import ( make_coord_getter, make_length_getter, - merge_metadata_horizontally, - merge_metadata_vertically, - merge_pair_horizontally, - merge_pair_vertically, - concat_images_horizontally, - concat_images_vertically, - merge_group_horizontally, - merge_group_vertically, ) +from image_prediction.stitching.merging import merge_metadata_horizontally, merge_metadata_vertically, \ + merge_pair_horizontally, merge_pair_vertically, concat_images_horizontally, concat_images_vertically, \ + merge_group_horizontally, merge_group_vertically from test.conftest import ( get_base_position_metadata, add_image, @@ -38,14 +33,9 @@ x1_getter, y1_getter, x2_getter, y2_getter = map(make_coord_getter, ("x1", "y1", width_getter, height_getter = map(make_length_getter, ("width", "height")) -##################################### - - def test_image_stitcher(patch_image_metadata_pairs, base_patch_metadata, base_patch_image): - pair_stitched = Stitcher().stitch(patch_image_metadata_pairs)[0] + pair_stitched = stitch_pairs(patch_image_metadata_pairs)[0] assert pair_stitched.metadata == base_patch_metadata - # pair_stitched.image.show() - # base_patch_image.show() assert images_equal(pair_stitched.image.resize((10, 10)), base_patch_image.resize((10, 10)), atol=0.4)