import abc
import random
from copy import deepcopy
from itertools import chain
from operator import itemgetter

from funcy import rpartial, juxt

from image_prediction.info import Info


class SplitKeyMapper(abc.ABC):
    """Axis-agnostic view onto a box dict.

    The generic names ``"dim"``, ``"c1"`` and ``"c2"`` are translated through
    ``keymap`` into the concrete keys of ``wrapped``, so the splitting logic can
    be written once and reused for either axis. All reads and writes go
    directly to ``wrapped``; the mapper keeps no copy of the data.
    """

    def __init__(self, wrapped: dict, keymap: dict):
        self.wrapped = wrapped  # the underlying box dict that is read/written
        self.keymap = keymap  # generic name -> concrete key in ``wrapped``

    def __getitem__(self, key):
        return self.wrapped[self.keymap[key]]

    def __setitem__(self, key, value):
        self.wrapped[self.keymap[key]] = value

    @property
    def dim(self):
        """Size of the box along the mapped axis."""
        return self["dim"]

    @dim.setter
    def dim(self, value):
        self["dim"] = value

    @property
    def c1(self):
        """Start coordinate of the box along the mapped axis."""
        return self["c1"]

    @c1.setter
    def c1(self, value):
        self["c1"] = value

    @property
    def c2(self):
        """End coordinate of the box along the mapped axis."""
        return self["c2"]

    @c2.setter
    def c2(self, value):
        self["c2"] = value


class HorizontalKeyMapper(SplitKeyMapper):
    """Map ``dim``/``c1``/``c2`` to ``Info.WIDTH``/``Info.X1``/``Info.X2``."""

    def __init__(self, wrapped: dict):
        super().__init__(wrapped, {"dim": Info.WIDTH, "c1": Info.X1, "c2": Info.X2})


class VerticalKeyMapper(SplitKeyMapper):
    """Map ``dim``/``c1``/``c2`` to ``Info.HEIGHT``/``Info.Y1``/``Info.Y2``."""

    def __init__(self, wrapped: dict):
        super().__init__(wrapped, {"dim": Info.HEIGHT, "c1": Info.Y1, "c2": Info.Y2})


class BoxSplitter:
    """Randomly partition a box into smaller adjacent boxes.

    At every recursion level each current box independently picks a random
    axis (horizontal or vertical) and is cut in two at a random position along
    that axis. A box that is too small along the chosen axis is passed through
    unchanged to the next level.
    """

    def __init__(self, *, min_part_size: int = 5):
        """
        Args:
            min_part_size: Smallest size either half of a cut may have along
                the cut axis. A box is only cut when its size along that axis
                is at least ``2 * min_part_size``. Defaults to 5, which gives
                the previous fixed behaviour (split threshold 10).

        Raises:
            ValueError: If ``min_part_size`` is smaller than 1 (that would
                allow zero-sized boxes).
        """
        if min_part_size < 1:
            raise ValueError(f"min_part_size must be >= 1, got {min_part_size}")
        self.__min_part_size = min_part_size

    def split_box(self, box, steps=5):
        """Split ``box`` recursively for ``steps`` levels.

        Args:
            box: Box dict addressed through the ``Info`` keys used by the key
                mappers (``WIDTH``/``HEIGHT``, ``X1``/``X2``, ``Y1``/``Y2``).
            steps: Number of recursion levels; each level may cut every box
                once. ``steps <= 0`` returns ``[box]``.

        Returns:
            A list of box dicts, evaluated eagerly. Boxes produced by a cut
            are deep copies of their parent with updated size and coordinates.
            A box that is never cut is returned as the very object passed in.
        """
        # ``steps`` is threaded through the recursion instead of being stored
        # on ``self``; repeated calls on one splitter therefore cannot
        # influence each other's depth.
        return list(self.__split_recursively(box, 0, steps))

    def __split_recursively(self, box, step, steps):
        """Split ``box`` if recursion levels remain, else return it as a leaf."""
        if step < steps:
            return self.__split_and_recurse(box, step, steps)
        return self.__base_case(box)

    @staticmethod
    def __base_case(box):
        """Wrap a box that is not split any further in a one-element list."""
        return [box]

    def __split_and_recurse(self, box, step, steps):
        """Cut ``box`` once, then recurse into each child one level deeper."""
        children = self.__random_split(box)
        # Advance the level counter exactly once per level so that ``steps``
        # really is the number of splitting levels performed.
        recurse = rpartial(self.__split_recursively, step + 1, steps)
        return list(chain.from_iterable(map(recurse, children)))

    def __random_split(self, box):
        """Cut ``box`` along a randomly chosen axis."""
        splitter = random.choice([self.__split_horizontal, self.__split_vertical])
        return splitter(box)

    def __split_horizontal(self, box):
        return self.__split_if_large_enough(HorizontalKeyMapper(box))

    def __split_vertical(self, box):
        return self.__split_if_large_enough(VerticalKeyMapper(box))

    def __split_if_large_enough(self, wrapped_box: SplitKeyMapper):
        """Return two child boxes, or ``[box]`` unchanged if it is too small."""
        if self.__large_enough(wrapped_box):
            return self.__get_child_boxes(wrapped_box)
        return self.__base_case(wrapped_box.wrapped)

    def __large_enough(self, wrapped_box: SplitKeyMapper):
        """True if both halves can be at least ``min_part_size`` long."""
        return wrapped_box.dim >= 2 * self.__min_part_size

    def __get_child_boxes(self, wrapped_box: SplitKeyMapper):
        """Cut the box at a random position along the mapped axis.

        Both children are deep copies of the parent, so the parent dict is
        left untouched.
        """
        min_part = self.__min_part_size
        # NOTE(review): assumes integer-valued sizes; random.randint requires
        # integers, so float coordinates would fail here.
        split_len = random.randint(min_part, wrapped_box.dim - min_part)
        split_point = wrapped_box.c1 + split_len
        box_left, box_right = juxt(deepcopy, deepcopy)(wrapped_box)
        # Left child spans [c1, split_point], right child [split_point, c2].
        box_left.dim = split_len
        box_left.c2 = split_point
        box_right.dim = wrapped_box.dim - split_len
        box_right.c1 = split_point
        return box_left.wrapped, box_right.wrapped