From 51793d19e93d083e26d6bb3b720b32105aee963d Mon Sep 17 00:00:00 2001 From: Matthias Bisping Date: Thu, 7 Apr 2022 21:39:01 +0200 Subject: [PATCH] refactoring --- .../classifier/image_classifier.py | 3 +- .../extractor_classifier.py | 2 +- image_prediction/stitcher/__init__.py | 0 image_prediction/stitcher/stitcher.py | 53 +++++ image_prediction/stitcher/utils.py | 142 ++++++++++++ image_prediction/utils/__init__.py | 5 - image_prediction/utils/generic.py | 14 ++ test/unit_tests/image_classifier_test.py | 2 +- test/unit_tests/image_stitcher_test.py | 219 ++---------------- 9 files changed, 231 insertions(+), 209 deletions(-) create mode 100644 image_prediction/stitcher/__init__.py create mode 100644 image_prediction/stitcher/stitcher.py create mode 100644 image_prediction/stitcher/utils.py create mode 100644 image_prediction/utils/generic.py diff --git a/image_prediction/classifier/image_classifier.py b/image_prediction/classifier/image_classifier.py index fd8d6b2..1b9ca84 100644 --- a/image_prediction/classifier/image_classifier.py +++ b/image_prediction/classifier/image_classifier.py @@ -7,7 +7,8 @@ from funcy import rcompose from image_prediction.classifier.classifier import Classifier from image_prediction.estimator.preprocessor.preprocessor import Preprocessor from image_prediction.estimator.preprocessor.preprocessors.identity import IdentityPreprocessor -from image_prediction.utils import chunk_iterable, get_logger +from image_prediction.utils import get_logger +from image_prediction.utils.generic import chunk_iterable logger = get_logger() diff --git a/image_prediction/extractor_classifier/extractor_classifier.py b/image_prediction/extractor_classifier/extractor_classifier.py index 95d217b..31b9c46 100644 --- a/image_prediction/extractor_classifier/extractor_classifier.py +++ b/image_prediction/extractor_classifier/extractor_classifier.py @@ -3,7 +3,7 @@ from typing import Iterable from image_prediction.classifier.image_classifier import ImageClassifier from image_prediction.image_extractor.extractor import ImageExtractor, ImageMetadataPair -from image_prediction.utils import chunk_iterable +from image_prediction.utils.generic import chunk_iterable class ExtractorClassifier: diff --git a/image_prediction/stitcher/__init__.py b/image_prediction/stitcher/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/image_prediction/stitcher/stitcher.py b/image_prediction/stitcher/stitcher.py new file mode 100644 index 0000000..1f7d8a2 --- /dev/null +++ b/image_prediction/stitcher/stitcher.py @@ -0,0 +1,53 @@ +from itertools import groupby, chain +from typing import Iterable, List + +from funcy import compose, second + +from image_prediction.image_extractor.extractor import ImageMetadataPair +from image_prediction.stitcher.utils import make_coord_getter, make_group_merger +from image_prediction.utils.generic import until_convergence + + +class Stitcher: + @staticmethod + def groupby(pairs, coord_getter): + pairs = sorted(pairs, key=coord_getter) + return map(compose(list, second), groupby(pairs, coord_getter)) + + @staticmethod + def other_axis(axis): + return "y" if axis == "x" else "x" + + def merge_along_axis(self, pairs, axis): + def group_pairs_by_c1(pairs): + return self.groupby(pairs, c1_getter) + + def group_by_c2(pairs): + return self.groupby(pairs, c2_getter) + + def group_pairs_within_groups_by_c2(groups): + return map(group_by_c2, groups) + + def merge_groups_along_orthogonal_axis(groups): + return map(group_merger, groups) + + c1_getter = make_coord_getter(f"{self.other_axis(axis)}1") + c2_getter = make_coord_getter(f"{self.other_axis(axis)}2") + group_merger = make_group_merger(axis) + + groups_of_pairs_with_same_c1 = group_pairs_by_c1(pairs) + groups_of_groups_of_pairs_with_same_c1_and_c2 = group_pairs_within_groups_by_c2(groups_of_pairs_with_same_c1) + groups_of_pairs_with_matching_c1_and_c2 = chain(*groups_of_groups_of_pairs_with_same_c1_and_c2) + groups_of_merged_pairs = merge_groups_along_orthogonal_axis(groups_of_pairs_with_matching_c1_and_c2) + pairs = chain(*groups_of_merged_pairs) + + return pairs + + def merge_along_both_axes(self, pairs): + pairs = self.merge_along_axis(pairs, "x") + pairs = list(self.merge_along_axis(pairs, "y")) + + return pairs + + def stitch(self, pairs: Iterable[ImageMetadataPair]) -> List[ImageMetadataPair]: + return until_convergence(self.merge_along_both_axes, pairs) diff --git a/image_prediction/stitcher/utils.py b/image_prediction/stitcher/utils.py new file mode 100644 index 0000000..ddbd0a3 --- /dev/null +++ b/image_prediction/stitcher/utils.py @@ -0,0 +1,142 @@ +from copy import deepcopy +from functools import reduce +from typing import Iterable + +from PIL import Image +from funcy import juxt, first, rest + +from image_prediction.image_extractor.extractor import ImageMetadataPair +from image_prediction.info import Info +from image_prediction.utils.generic import until_convergence +from test.utils.stitching import HorizontalKeyMapper, VerticalKeyMapper + + +def make_getter(key): + def getter(pair): + return pair.metadata[key] + + return getter + + +def make_coord_getter(c): + return { + "x1": make_getter(Info.X1), + "x2": make_getter(Info.X2), + "y1": make_getter(Info.Y1), + "y2": make_getter(Info.Y2), + }[c] + + +def make_pair_merger(axis): + return {"y": merge_pair_vertically, "x": merge_pair_horizontally}[axis] + + +def make_group_merger(axis): + return {"y": merge_group_vertically, "x": merge_group_horizontally}[axis] + + +def make_length_getter(dim): + return { + "width": make_getter(Info.WIDTH), + "height": make_getter(Info.HEIGHT), + }[dim] + + +def merge_metadata_horizontally(m1, m2): + m1, m2 = map(HorizontalKeyMapper, [m1, m2]) + return merge_metadata(m1, m2) + + +def merge_metadata_vertically(m1, m2): + m1, m2 = map(VerticalKeyMapper, [m1, m2]) + return merge_metadata(m1, m2) + + +def merge_metadata(m1, m2): + + c1 = min(m1.c1, m2.c1) + c2 = max(m1.c2, m2.c2) + dim = m1.dim + m2.dim + + merged = deepcopy(m1) + merged.dim = dim + merged.c1 = c1 + merged.c2 = c2 + + return merged.wrapped + + +def merge_pair_horizontally(p1: ImageMetadataPair, p2: ImageMetadataPair): + metadata_merged = merge_metadata_horizontally(p1.metadata, p2.metadata) + image_concatenated = concat_images_horizontally(p1.image, p2.image, metadata_merged) + return ImageMetadataPair(image_concatenated, metadata_merged) + + +def merge_pair_vertically(p1: ImageMetadataPair, p2: ImageMetadataPair): + metadata_merged = merge_metadata_vertically(p1.metadata, p2.metadata) + image_concatenated = concat_images_vertically(p1.image, p2.image, metadata_merged) + return ImageMetadataPair(image_concatenated, metadata_merged) + + +def concat_images_horizontally(im1: Image, im2: Image, metadata: dict): + return concat_images(im1, im2, metadata, 0) + + +def concat_images_vertically(im1: Image, im2: Image, metadata: dict): + return concat_images(im1, im2, metadata, 1) + + +def concat_images(im1: Image, im2: Image, metadata: dict, axis): + + im_aggr = Image.new(im1.mode, (metadata[Info.WIDTH], metadata[Info.HEIGHT])) + + images = [im1, im2] + + offsets = [0, *[im.size[axis] for im in images]] + + for im, offset in zip(images, offsets): + box = (offset, 0) if not axis else (0, offset) + im_aggr.paste(im, box=box) + + return im_aggr + + +def make_merger_aggregator(direction): + """Produces a function f : [H, T1, ... Tn] -> [HTi...Tj, Tk ... Tl] that merges adjacent image-metadata pairs on the + head H and aggregates non-adjacent in the tail T. + """ + + def merger_aggregator(pairs): + def merge_on_head_and_aggregate_in_tail(pairs_aggr: Iterable[ImageMetadataPair], pair: ImageMetadataPair): + """Keeps the image that is being merged with as the head and aggregates non-mergables in the tail.""" + aggr, non_aggr = juxt(first, rest)(pairs_aggr) + if c2_getter(aggr) == c1_getter(pair): + aggr = pair_merger(aggr, pair) + return aggr, *non_aggr + else: + return aggr, pair, *non_aggr + + # Requires H to be the least element in image-concatenation direction by c1, since the concatenation happens + # only in c1 -> c2 direction. + pairs = sorted(pairs, key=c1_getter) + aggregation_pair, pairs = juxt(first, rest)(pairs) + return list(reduce(merge_on_head_and_aggregate_in_tail, pairs, [aggregation_pair])) + + c1_getter = make_coord_getter(f"{direction}1") + c2_getter = make_coord_getter(f"{direction}2") + pair_merger = make_pair_merger(direction) + + return merger_aggregator + + +def merge_group(group, direction): + reduce_group = make_merger_aggregator(direction) + return until_convergence(reduce_group, group) + + +def merge_group_horizontally(group): + return merge_group(group, "x") + + +def merge_group_vertically(group): + return merge_group(group, "y") diff --git a/image_prediction/utils/__init__.py b/image_prediction/utils/__init__.py index d8ef2e6..f9e558e 100644 --- a/image_prediction/utils/__init__.py +++ b/image_prediction/utils/__init__.py @@ -1,8 +1,3 @@ -from itertools import takewhile, starmap, islice, repeat -from operator import truth - from .logger import get_logger -def chunk_iterable(iterable, chunk_size): - return takewhile(truth, map(tuple, starmap(islice, repeat((iter(iterable), chunk_size))))) diff --git a/image_prediction/utils/generic.py b/image_prediction/utils/generic.py new file mode 100644 index 0000000..aa7228a --- /dev/null +++ b/image_prediction/utils/generic.py @@ -0,0 +1,14 @@ +from itertools import takewhile, starmap, islice, repeat +from operator import truth + +from funcy import iterate + + +def chunk_iterable(iterable, chunk_size): + return takewhile(truth, map(tuple, starmap(islice, repeat((iter(iterable), chunk_size))))) + + +def until_convergence(func, *args, **kwargs): + for a, b in chunk_iterable(iterate(func, *args, **kwargs), chunk_size=2): + if len(a) == len(b): + return a diff --git a/test/unit_tests/image_classifier_test.py b/test/unit_tests/image_classifier_test.py index 7dff19b..801030a 100644 --- a/test/unit_tests/image_classifier_test.py +++ b/test/unit_tests/image_classifier_test.py @@ -1,6 +1,6 @@ import pytest -from image_prediction.utils import chunk_iterable +from image_prediction.utils.generic import chunk_iterable @pytest.mark.parametrize("estimator_type", ["mock", "keras"]) diff --git a/test/unit_tests/image_stitcher_test.py b/test/unit_tests/image_stitcher_test.py index 1896fd9..9dabe6a 100644 --- a/test/unit_tests/image_stitcher_test.py +++ b/test/unit_tests/image_stitcher_test.py @@ -1,229 +1,46 @@ from copy import deepcopy -from functools import partial, reduce -from itertools import groupby -from itertools import starmap, chain, repeat -from typing import Iterable, List +from functools import partial +from itertools import starmap, repeat +from typing import List import fpdf import numpy as np import pdf2image import pytest from PIL import Image -from funcy import merge, second, compose, rpartial, juxt, rest, first, one, iterate +from funcy import merge, juxt, one from image_prediction.estimator.preprocessor.utils import image_to_normalized_tensor from image_prediction.image_extractor.extractor import ImageMetadataPair from image_prediction.info import Info -from image_prediction.utils import chunk_iterable +from image_prediction.stitcher.stitcher import Stitcher +from image_prediction.stitcher.utils import ( + make_coord_getter, + make_length_getter, + merge_metadata_horizontally, + merge_metadata_vertically, + merge_pair_horizontally, + merge_pair_vertically, + concat_images_horizontally, + concat_images_vertically, + merge_group_horizontally, + merge_group_vertically, +) from test.conftest import ( get_base_position_metadata, add_image, random_single_color_image_from_metadata, random_size_gray_image_from_metadata, ) -from test.utils.stitching import BoxSplitter, VerticalKeyMapper, HorizontalKeyMapper - - -def make_getter(key): - def getter(pair): - return pair.metadata[key] - - return getter - - -def make_coord_getter(c): - return { - "x1": make_getter(Info.X1), - "x2": make_getter(Info.X2), - "y1": make_getter(Info.Y1), - "y2": make_getter(Info.Y2), - }[c] - - -def make_pair_merger(axis): - return {"y": merge_pair_vertically, "x": merge_pair_horizontally}[axis] - - -def make_group_merger(axis): - return {"y": merge_group_vertically, "x": merge_group_horizontally}[axis] - - -def make_length_getter(dim): - return { - "width": make_getter(Info.WIDTH), - "height": make_getter(Info.HEIGHT), - }[dim] - +from test.utils.stitching import BoxSplitter x1_getter, y1_getter, x2_getter, y2_getter = map(make_coord_getter, ("x1", "y1", "x2", "y2")) width_getter, height_getter = map(make_length_getter, ("width", "height")) -def merge_metadata_horizontally(m1, m2): - m1, m2 = map(HorizontalKeyMapper, [m1, m2]) - return merge_metadata(m1, m2) - - -def merge_metadata_vertically(m1, m2): - m1, m2 = map(VerticalKeyMapper, [m1, m2]) - return merge_metadata(m1, m2) - - -def merge_metadata(m1, m2): - - c1 = min(m1.c1, m2.c1) - c2 = max(m1.c2, m2.c2) - dim = m1.dim + m2.dim - - merged = deepcopy(m1) - merged.dim = dim - merged.c1 = c1 - merged.c2 = c2 - - return merged.wrapped - - -def merge_pair_horizontally(p1: ImageMetadataPair, p2: ImageMetadataPair): - metadata_merged = merge_metadata_horizontally(p1.metadata, p2.metadata) - image_concatenated = concat_images_horizontally(p1.image, p2.image, metadata_merged) - return ImageMetadataPair(image_concatenated, metadata_merged) - - -def merge_pair_vertically(p1: ImageMetadataPair, p2: ImageMetadataPair): - metadata_merged = merge_metadata_vertically(p1.metadata, p2.metadata) - image_concatenated = concat_images_vertically(p1.image, p2.image, metadata_merged) - return ImageMetadataPair(image_concatenated, metadata_merged) - - -# def merge_pair(p1, p2): -# assert p1.metadata[Info.PAGE_IDX] == p2.metadta[Info.PAGE_IDX] - - -def concat_images_horizontally(im1: Image, im2: Image, metadata: dict): - return concat_images(im1, im2, metadata, 0) - - -def concat_images_vertically(im1: Image, im2: Image, metadata: dict): - return concat_images(im1, im2, metadata, 1) - - -def concat_images(im1: Image, im2: Image, metadata: dict, axis): - - im_aggr = Image.new(im1.mode, (metadata[Info.WIDTH], metadata[Info.HEIGHT])) - - images = [im1, im2] - - offsets = [0, *[im.size[axis] for im in images]] - - for im, offset in zip(images, offsets): - box = (offset, 0) if not axis else (0, offset) - im_aggr.paste(im, box=box) - - return im_aggr - - -def make_merger_aggregator(direction): - """Produces a function f : [H, T1, ... Tn] -> [HTi...Tj, Tk ... Tl] that merges adjacent image-metadata pairs on the - head H and aggregates non-adjacent in the tail T. - """ - - def merger_aggregator(pairs): - def merge_on_head_and_aggregate_in_tail(pairs_aggr: Iterable[ImageMetadataPair], pair: ImageMetadataPair): - """Keeps the image that is being merged with as the head and aggregates non-mergables in the tail.""" - aggr, non_aggr = juxt(first, rest)(pairs_aggr) - if c2_getter(aggr) == c1_getter(pair): - aggr = pair_merger(aggr, pair) - return aggr, *non_aggr - else: - return aggr, pair, *non_aggr - - # Requires H to be the least element in image-concatenation direction by c1, since the concatenation happens - # only in c1 -> c2 direction. - pairs = sorted(pairs, key=c1_getter) - aggregation_pair, pairs = juxt(first, rest)(pairs) - return list(reduce(merge_on_head_and_aggregate_in_tail, pairs, [aggregation_pair])) - - c1_getter = make_coord_getter(f"{direction}1") - c2_getter = make_coord_getter(f"{direction}2") - pair_merger = make_pair_merger(direction) - - return merger_aggregator - - -def merge_group(group, direction): - reduce_group = make_merger_aggregator(direction) - return until_convergence(reduce_group, group) - - -def merge_group_horizontally(group): - return merge_group(group, "x") - - -def merge_group_vertically(group): - return merge_group(group, "y") - - -class Stitcher: - - @staticmethod - def groupby(pairs, coord_getter): - pairs = sorted(pairs, key=coord_getter) - return map(compose(list, second), groupby(pairs, coord_getter)) - - @staticmethod - def other_axis(axis): - return "y" if axis == "x" else "x" - - def merge_along_axis(self, pairs, axis): - - def group_pairs_by_c1(pairs): - return self.groupby(pairs, c1_getter) - - def group_by_c2(pairs): - return self.groupby(pairs, c2_getter) - - def group_pairs_within_groups_by_c2(groups): - return map(group_by_c2, groups) - - def merge_groups_along_orthogonal_axis(groups): - return map(group_merger, groups) - - c1_getter = make_coord_getter(f"{self.other_axis(axis)}1") - c2_getter = make_coord_getter(f"{self.other_axis(axis)}2") - group_merger = make_group_merger(axis) - - groups_of_pairs_with_same_c1 = group_pairs_by_c1(pairs) - groups_of_groups_of_pairs_with_same_c1_and_c2 = group_pairs_within_groups_by_c2(groups_of_pairs_with_same_c1) - groups_of_pairs_with_matching_c1_and_c2 = chain(*groups_of_groups_of_pairs_with_same_c1_and_c2) - groups_of_merged_pairs = merge_groups_along_orthogonal_axis(groups_of_pairs_with_matching_c1_and_c2) - pairs = chain(*groups_of_merged_pairs) - - return pairs - - def merge_along_both_axes(self, pairs): - - pairs = self.merge_along_axis(pairs, "x") - pairs = list(self.merge_along_axis(pairs, "y")) - - return pairs - - def stitch(self, pairs: Iterable[ImageMetadataPair]) -> List[ImageMetadataPair]: - return until_convergence(self.merge_along_both_axes, pairs) - - -def until_convergence(func, *args, **kwargs): - for a, b in chunk_iterable(iterate(func, *args, **kwargs), chunk_size=2): - if len(a) == len(b): - return a - - ##################################### -# @pytest.mark.parametrize("width", [160]) -# @pytest.mark.parametrize("height", [90]) -# @pytest.mark.parametrize("page_width", [int(160 * 1.1)]) -# @pytest.mark.parametrize("page_height", [int(90 * 1.1)]) def test_image_stitcher(patch_image_metadata_pairs, base_patch_metadata, base_patch_image): pair_stitched = Stitcher().stitch(patch_image_metadata_pairs)[0] assert pair_stitched.metadata == base_patch_metadata