diff --git a/test/unit_tests/image_stitcher_test.py b/test/unit_tests/image_stitcher_test.py index 894c3e5..2e70c84 100644 --- a/test/unit_tests/image_stitcher_test.py +++ b/test/unit_tests/image_stitcher_test.py @@ -1,16 +1,11 @@ -import abc -import random -from copy import deepcopy -from itertools import chain -from operator import itemgetter - import fpdf import pytest -from funcy import juxt, merge, rpartial, compose +from funcy import merge from image_prediction.image_extractor.extractor import ImageMetadataPair from image_prediction.info import Info from test.conftest import get_base_position_metadata, add_image, random_single_color_image_from_metadata +from test.utils.stitching import BoxSplitter def test_image_stitcher(partial_image_metadata_pairs): @@ -31,89 +26,6 @@ def test_partial_image_metadata_pairs(patches_metadata, page_width, page_height) pdf.output("/tmp/bla.pdf") -class SplitKeyMapper(abc.ABC): - def __init__(self, box: dict, keymap: dict): - self.box = box - self.keymap = keymap - - def __getitem__(self, item): - return self.box[self.keymap[item]] - - def __setitem__(self, key, value): - self.box[self.keymap[key]] = value - - -class HorizontalKeyMapper(SplitKeyMapper): - def __init__(self, box: dict): - super().__init__(box, {"dim": Info.WIDTH, "c1": Info.X1, "c2": Info.X2}) - - -class VerticalKeyMapper(SplitKeyMapper): - def __init__(self, box: dict): - super().__init__(box, {"dim": Info.HEIGHT, "c1": Info.Y1, "c2": Info.Y2}) - - -class BoxSplitter: - def __init__(self): - self.__steps = None - - def split_box(self, box, steps=5): - self.__steps = steps - return self.__split_recursively(box, 0) - - def __split_recursively(self, box, step): - return self.__split_and_recurse(box, step) if self.__steps_left(step) else self.__base_case(box) - - def __steps_left(self, step): - return step < self.__steps - - @staticmethod - def __base_case(box): - return [box] - - def __split_and_recurse(self, box, step): - new_boxes = self.__random_split(box) - new_boxes_per_branch = self.__tree_recurse(new_boxes, step + 1) - return chain.from_iterable(new_boxes_per_branch) - - def __random_split(self, box): - splitter = random.choice([self.__split_horizontal, self.__split_vertical]) - new_boxes = splitter(box) - return new_boxes - - def __tree_recurse(self, boxes, step): - return map(rpartial(self.__split_recursively, step + 1), boxes) - - def __split_horizontal(self, box): - return self.__split_if_large_enough(HorizontalKeyMapper(box)) - - def __split_vertical(self, box): - return self.__split_if_large_enough(VerticalKeyMapper(box)) - - def __split_if_large_enough(self, sabox: SplitKeyMapper): - return self.__get_child_boxes(sabox) if self.__large_enough(sabox) else self.__base_case(sabox.box) - - @staticmethod - def __large_enough(box): - return box["dim"] >= 10 - - @staticmethod - def __get_child_boxes(sabox: SplitKeyMapper): - - split_len = random.randint(5, sabox["dim"] - 5) - split_point = sabox["c1"] + split_len - - box_left, box_right = juxt(deepcopy, deepcopy)(sabox) - - box_left["dim"] = split_len - box_right["dim"] = sabox["dim"] - split_len - - box_left["c2"] = split_point - box_right["c1"] = split_point - - return box_left.box, box_right.box - - @pytest.fixture() def patches_metadata(width, height, page_width, page_height): box = get_base_position_metadata(width, height, page_width, page_height) diff --git a/test/utils/__init__.py b/test/utils/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/test/utils/stitching.py b/test/utils/stitching.py new file mode 100644 index 0000000..e9d435c --- /dev/null +++ b/test/utils/stitching.py @@ -0,0 +1,91 @@ +import abc +import random +from copy import deepcopy +from itertools import chain + +from funcy import rpartial, juxt + +from image_prediction.info import Info + + +class SplitKeyMapper(abc.ABC): + def __init__(self, box: dict, keymap: dict): + self.box = box + self.keymap = keymap + + def __getitem__(self, item): + return self.box[self.keymap[item]] + + def __setitem__(self, key, value): + self.box[self.keymap[key]] = value + + +class HorizontalKeyMapper(SplitKeyMapper): + def __init__(self, box: dict): + super().__init__(box, {"dim": Info.WIDTH, "c1": Info.X1, "c2": Info.X2}) + + +class VerticalKeyMapper(SplitKeyMapper): + def __init__(self, box: dict): + super().__init__(box, {"dim": Info.HEIGHT, "c1": Info.Y1, "c2": Info.Y2}) + + +class BoxSplitter: + def __init__(self): + self.__steps = None + + def split_box(self, box, steps=5): + self.__steps = steps + return self.__split_recursively(box, 0) + + def __split_recursively(self, box, step): + return self.__split_and_recurse(box, step) if self.__steps_left(step) else self.__base_case(box) + + def __steps_left(self, step): + return step < self.__steps + + @staticmethod + def __base_case(box): + return [box] + + def __split_and_recurse(self, box, step): + new_boxes = self.__random_split(box) + new_boxes_per_branch = self.__tree_recurse(new_boxes, step + 1) + return chain.from_iterable(new_boxes_per_branch) + + def __random_split(self, box): + splitter = random.choice([self.__split_horizontal, self.__split_vertical]) + new_boxes = splitter(box) + return new_boxes + + def __tree_recurse(self, boxes, step): + return map(rpartial(self.__split_recursively, step + 1), boxes) + + def __split_horizontal(self, box): + return self.__split_if_large_enough(HorizontalKeyMapper(box)) + + def __split_vertical(self, box): + return self.__split_if_large_enough(VerticalKeyMapper(box)) + + def __split_if_large_enough(self, sabox: SplitKeyMapper): + return self.__get_child_boxes(sabox) if self.__large_enough(sabox) else self.__base_case(sabox.box) + + @staticmethod + def __large_enough(box): + return box["dim"] >= 10 + + @staticmethod + def __get_child_boxes(sabox: SplitKeyMapper): + + split_len = random.randint(5, sabox["dim"] - 5) + split_point = sabox["c1"] + split_len + + box_left, box_right = juxt(deepcopy, deepcopy)(sabox) + + box_left["dim"] = split_len + box_right["dim"] = sabox["dim"] - split_len + + box_left["c2"] = split_point + box_right["c1"] = split_point + + return box_left.box, box_right.box \ No newline at end of file