diff --git a/test/unit_tests/image_stitcher_test.py b/test/unit_tests/image_stitcher_test.py index 6dd4825..9f643e2 100644 --- a/test/unit_tests/image_stitcher_test.py +++ b/test/unit_tests/image_stitcher_test.py @@ -1,13 +1,12 @@ -import json import random from copy import deepcopy from itertools import chain +from operator import itemgetter import fpdf import pytest from funcy import juxt, merge, rpartial -from image_prediction.formatter.formatters.enum import EnumFormatter from image_prediction.image_extractor.extractor import ImageMetadataPair from image_prediction.info import Info from test.conftest import get_base_position_metadata, add_image, random_single_color_image_from_metadata @@ -22,7 +21,6 @@ def test_image_stitcher(partial_image_metadata_pairs): @pytest.mark.parametrize("page_width", [int(160 * 1.1)]) @pytest.mark.parametrize("page_height", [int(90 * 1.1)]) def test_partial_image_metadata_pairs(patches_metadata, page_width, page_height): - pdf = fpdf.FPDF(unit="pt", format=(page_width, page_height)) for patch in patches_metadata: @@ -32,45 +30,72 @@ def test_partial_image_metadata_pairs(patches_metadata, page_width, page_height) pdf.output("/tmp/bla.pdf") -def split_box(box, max_step=3): - def split_recursively(box, step): - def split_horizontal(): - return split(Info.WIDTH, Info.X1, Info.X2) +class BoxSplitter: + def __init__(self): + self.__steps = None + self.__horiz_keymap = {"dim": Info.WIDTH, "c1": Info.X1, "c2": Info.X2} + self.__vert_keymap = {"dim": Info.HEIGHT, "c1": Info.Y1, "c2": Info.Y2} - def split_vertical(): - return split(Info.HEIGHT, Info.Y1, Info.Y2) + def split_box(self, box, steps=5): + self.__steps = steps + return self.__split_recursively(box, 0) - def split(dim, coord1, coord2): + def __split_recursively(self, box, step): + return self.__split_and_recurse(box, step) if self.__steps_left(step) else self.__base_case(box) - if box[dim] >= 10: - split_len = random.randint(5, box[dim] - 5) - split_point = box[coord1] + split_len + def __steps_left(self, step): + return step < self.__steps - box_left, box_right = juxt(deepcopy, deepcopy)(box) + @staticmethod + def __base_case(box): + return [box] - box_left[dim] = split_len - box_right[dim] = box[dim] - split_len + def __split_and_recurse(self, box, step): + new_boxes = self.__random_split(box) + return chain.from_iterable(self.__tree_recurse(new_boxes, step + 1)) - box_left[coord2] = split_point - box_right[coord1] = split_point + def __random_split(self, box): + splitter = random.choice([self.__split_horizontal, self.__split_vertical]) + new_boxes = splitter(box) + return new_boxes - return box_left, box_right - else: - return [box] + def __tree_recurse(self, boxes, step): + return map(rpartial(self.__split_recursively, step + 1), boxes) - if step < max_step: - new_boxes = random.choice([split_horizontal, split_vertical])() + def __split_horizontal(self, box): + return self.__split_if_large_enough(box, self.__horiz_keymap) - return chain.from_iterable(map(rpartial(split_recursively, step + 1), new_boxes)) - else: - return [box] + def __split_vertical(self, box): + return self.__split_if_large_enough(box, self.__vert_keymap) - return split_recursively(box, 0) + def __split_if_large_enough(self, box, keymap): + return self.__get_child_boxes(box, keymap) if self.__large_enough(box, keymap) else self.__base_case(box) + + @staticmethod + def __large_enough(box, keymap): + return box[keymap["dim"]] >= 10 + + @staticmethod + def __get_child_boxes(box, keymap): + dim, c1, c2 = itemgetter("dim", "c1", "c2")(keymap) + + split_len = random.randint(5, box[dim] - 5) + split_point = box[c1] + split_len + + box_left, box_right = juxt(deepcopy, deepcopy)(box) + + box_left[dim] = split_len + box_right[dim] = box[dim] - split_len + + box_left[c2] = split_point + box_right[c1] = split_point + + return box_left, box_right @pytest.fixture() def patches_metadata(width, height, page_width, page_height): box = get_base_position_metadata(width, height, page_width, page_height) box = merge(box, {Info.X1: 0, Info.Y1: 0, Info.X2: width, Info.Y2: height}) - boxes = split_box(box) + boxes = BoxSplitter().split_box(box) return boxes