diff --git a/test/unit_tests/image_stitcher_test.py b/test/unit_tests/image_stitcher_test.py index ff0bb9f..894c3e5 100644 --- a/test/unit_tests/image_stitcher_test.py +++ b/test/unit_tests/image_stitcher_test.py @@ -1,3 +1,4 @@ +import abc import random from copy import deepcopy from itertools import chain @@ -5,7 +6,7 @@ from operator import itemgetter import fpdf import pytest -from funcy import juxt, merge, rpartial +from funcy import juxt, merge, rpartial, compose from image_prediction.image_extractor.extractor import ImageMetadataPair from image_prediction.info import Info @@ -30,11 +31,31 @@ def test_partial_image_metadata_pairs(patches_metadata, page_width, page_height) pdf.output("/tmp/bla.pdf") +class SplitKeyMapper(abc.ABC): + def __init__(self, box: dict, keymap: dict): + self.box = box + self.keymap = keymap + + def __getitem__(self, item): + return self.box[self.keymap[item]] + + def __setitem__(self, key, value): + self.box[self.keymap[key]] = value + + +class HorizontalKeyMapper(SplitKeyMapper): + def __init__(self, box: dict): + super().__init__(box, {"dim": Info.WIDTH, "c1": Info.X1, "c2": Info.X2}) + + +class VerticalKeyMapper(SplitKeyMapper): + def __init__(self, box: dict): + super().__init__(box, {"dim": Info.HEIGHT, "c1": Info.Y1, "c2": Info.Y2}) + + class BoxSplitter: def __init__(self): self.__steps = None - self.__horiz_keymap = {"dim": Info.WIDTH, "c1": Info.X1, "c2": Info.X2} - self.__vert_keymap = {"dim": Info.HEIGHT, "c1": Info.Y1, "c2": Info.Y2} def split_box(self, box, steps=5): self.__steps = steps @@ -64,34 +85,33 @@ class BoxSplitter: return map(rpartial(self.__split_recursively, step + 1), boxes) def __split_horizontal(self, box): - return self.__split_if_large_enough(box, self.__horiz_keymap) + return self.__split_if_large_enough(HorizontalKeyMapper(box)) def __split_vertical(self, box): - return self.__split_if_large_enough(box, self.__vert_keymap) + return self.__split_if_large_enough(VerticalKeyMapper(box)) - def __split_if_large_enough(self, box, keymap): - return self.__get_child_boxes(box, keymap) if self.__large_enough(box, keymap) else self.__base_case(box) + def __split_if_large_enough(self, sabox: SplitKeyMapper): + return self.__get_child_boxes(sabox) if self.__large_enough(sabox) else self.__base_case(sabox.box) @staticmethod - def __large_enough(box, keymap): - return box[keymap["dim"]] >= 10 + def __large_enough(box): + return box["dim"] >= 10 @staticmethod - def __get_child_boxes(box, keymap): - dim, c1, c2 = itemgetter("dim", "c1", "c2")(keymap) + def __get_child_boxes(sabox: SplitKeyMapper): - split_len = random.randint(5, box[dim] - 5) - split_point = box[c1] + split_len + split_len = random.randint(5, sabox["dim"] - 5) + split_point = sabox["c1"] + split_len - box_left, box_right = juxt(deepcopy, deepcopy)(box) + box_left, box_right = juxt(deepcopy, deepcopy)(sabox) - box_left[dim] = split_len - box_right[dim] = box[dim] - split_len + box_left["dim"] = split_len + box_right["dim"] = sabox["dim"] - split_len - box_left[c2] = split_point - box_right[c1] = split_point + box_left["c2"] = split_point + box_right["c1"] = split_point - return box_left, box_right + return box_left.box, box_right.box @pytest.fixture()