122 lines
3.2 KiB
Python
122 lines
3.2 KiB
Python
import abc
|
|
import random
|
|
from copy import deepcopy
|
|
from itertools import chain
|
|
from operator import itemgetter
|
|
|
|
from funcy import rpartial, juxt
|
|
|
|
from image_prediction.info import Info
|
|
|
|
|
|
class SplitKeyMapper(abc.ABC):
    """Axis-neutral view over a wrapped dict.

    Lookups and assignments use the generic names ``"dim"``, ``"c1"`` and
    ``"c2"``; ``keymap`` translates each of them into the key under which the
    value actually lives in ``wrapped``. All access goes straight through to
    ``wrapped`` -- nothing is cached on the mapper itself.
    """

    def __init__(self, wrapped: dict, keymap: dict):
        self.wrapped = wrapped
        self.keymap = keymap

    def __getitem__(self, key):
        """Return the wrapped value stored under the key ``key`` maps to."""
        actual_key = self.keymap[key]
        return self.wrapped[actual_key]

    def __setitem__(self, key, value):
        """Store ``value`` in the wrapped dict under the key ``key`` maps to."""
        actual_key = self.keymap[key]
        self.wrapped[actual_key] = value

    @property
    def dim(self):
        """Value mapped from the generic name ``"dim"``."""
        return self["dim"]

    @dim.setter
    def dim(self, value):
        self["dim"] = value

    @property
    def c1(self):
        """Value mapped from the generic name ``"c1"``."""
        return self["c1"]

    @c1.setter
    def c1(self, value):
        self["c1"] = value

    @property
    def c2(self):
        """Value mapped from the generic name ``"c2"``."""
        return self["c2"]

    @c2.setter
    def c2(self, value):
        self["c2"] = value
|
|
|
|
|
|
class HorizontalKeyMapper(SplitKeyMapper):
    """Mapper whose generic names resolve to the x-axis keys.

    ``dim`` -> ``Info.WIDTH``, ``c1`` -> ``Info.X1``, ``c2`` -> ``Info.X2``.
    """

    def __init__(self, wrapped: dict):
        horizontal_keys = dict(dim=Info.WIDTH, c1=Info.X1, c2=Info.X2)
        super().__init__(wrapped, horizontal_keys)
|
|
|
|
|
|
class VerticalKeyMapper(SplitKeyMapper):
    """Mapper whose generic names resolve to the y-axis keys.

    ``dim`` -> ``Info.HEIGHT``, ``c1`` -> ``Info.Y1``, ``c2`` -> ``Info.Y2``.
    """

    def __init__(self, wrapped: dict):
        vertical_keys = dict(dim=Info.HEIGHT, c1=Info.Y1, c2=Info.Y2)
        super().__init__(wrapped, vertical_keys)
|
|
|
|
|
|
class BoxSplitter:
    """Randomly subdivide a box into smaller boxes.

    A box is a dict that ``HorizontalKeyMapper`` / ``VerticalKeyMapper`` can
    wrap, i.e. it carries the ``Info`` width/height and coordinate keys those
    mappers translate to. At every recursion level each box is split once
    along a randomly chosen axis, provided it is large enough along that axis;
    otherwise it is passed on unchanged.
    """

    # Smallest extent a child box may have along the split axis. A box must
    # be at least twice this long along an axis to be split along it.
    _MIN_CHILD_LEN = 5

    def __init__(self):
        # Recursion depth limit of the current split_box() call.
        self.__steps = None

    def split_box(self, box, steps=5):
        """Split ``box`` recursively and return the resulting boxes.

        Args:
            box: The box dict to subdivide. It is never mutated; children are
                deep copies. A box too small to split is returned as the same
                object.
            steps: Number of recursion levels. Every level attempts one split
                per box, so at most ``2 ** steps`` boxes are produced.

        Returns:
            A list of box dicts. With ``steps=0`` this is ``[box]``.
        """
        self.__steps = steps
        return list(self.__split_recursively(box, 0))

    def __split_recursively(self, box, step):
        """Split ``box`` further if recursion levels remain, else stop."""
        if not self.__steps_left(step):
            return self.__base_case(box)
        return self.__split_and_recurse(box, step)

    def __steps_left(self, step):
        """Return True while ``step`` is below the configured depth limit."""
        return step < self.__steps

    @staticmethod
    def __base_case(box):
        """Wrap an unsplit box in a list so every branch returns a list."""
        return [box]

    def __split_and_recurse(self, box, step):
        """Split ``box`` once, then recurse into each resulting box."""
        new_boxes = self.__random_split(box)
        new_boxes_per_branch = self.__tree_recurse(new_boxes, step + 1)
        return list(chain.from_iterable(new_boxes_per_branch))

    def __random_split(self, box):
        """Split ``box`` along a randomly chosen axis."""
        splitter = random.choice([self.__split_horizontal, self.__split_vertical])
        return splitter(box)

    def __tree_recurse(self, boxes, step):
        """Recurse into every box in ``boxes`` at depth ``step``.

        ``step`` is already the depth of the children. It was previously
        incremented here a second time, which made each level consume two
        steps and roughly halved the effective recursion depth.
        """
        return [self.__split_recursively(child, step) for child in boxes]

    def __split_horizontal(self, box):
        """Attempt a split along the axis mapped by ``HorizontalKeyMapper``."""
        return self.__split_if_large_enough(HorizontalKeyMapper(box))

    def __split_vertical(self, box):
        """Attempt a split along the axis mapped by ``VerticalKeyMapper``."""
        return self.__split_if_large_enough(VerticalKeyMapper(box))

    def __split_if_large_enough(self, wrapped_box: "SplitKeyMapper"):
        """Return two child boxes, or the original box if it is too small."""
        if self.__large_enough(wrapped_box):
            return self.__get_child_boxes(wrapped_box)
        return self.__base_case(wrapped_box.wrapped)

    @staticmethod
    def __large_enough(wrapped_box: "SplitKeyMapper"):
        """Return True if both children can reach the minimum child length."""
        return wrapped_box.dim >= 2 * BoxSplitter._MIN_CHILD_LEN

    @staticmethod
    def __get_child_boxes(wrapped_box: "SplitKeyMapper"):
        """Cut ``wrapped_box`` at a random point and return both halves.

        The first child keeps ``c1`` and ends at the split point; the second
        starts at the split point and keeps ``c2``. Both are deep copies, so
        the input box is left untouched.
        """
        min_len = BoxSplitter._MIN_CHILD_LEN
        # NOTE(review): random.randint needs integer bounds, so this assumes
        # box extents are ints -- confirm against callers.
        split_len = random.randint(min_len, wrapped_box.dim - min_len)
        split_point = wrapped_box.c1 + split_len

        box_left = deepcopy(wrapped_box)
        box_right = deepcopy(wrapped_box)

        box_left.dim = split_len
        box_right.dim = wrapped_box.dim - split_len

        box_left.c2 = split_point
        box_right.c1 = split_point

        return [box_left.wrapped, box_right.wrapped]
|