Matthias Bisping 315679468b applied black
2022-04-05 19:35:36 +02:00

92 lines
2.6 KiB
Python

import abc
import random
from copy import deepcopy
from itertools import chain
from funcy import rpartial, juxt
from image_prediction.info import Info
class SplitKeyMapper(abc.ABC):
    """View onto a box dict that is addressed through alias keys.

    ``keymap`` translates each alias into the key actually used in ``box``.
    Item reads and writes on the mapper are forwarded to ``box`` under the
    translated key, so writes modify ``box`` in place.
    """

    def __init__(self, box: dict, keymap: dict):
        """Store the wrapped ``box`` and the alias-to-key ``keymap``."""
        self.box, self.keymap = box, keymap

    def __getitem__(self, item):
        """Return the value of ``box`` stored under the key aliased by ``item``."""
        actual_key = self.keymap[item]
        return self.box[actual_key]

    def __setitem__(self, key, value):
        """Store ``value`` in ``box`` under the key aliased by ``key``."""
        actual_key = self.keymap[key]
        self.box[actual_key] = value
class HorizontalKeyMapper(SplitKeyMapper):
    """Key mapper whose aliases resolve to the box's width and x-coordinates.

    ``"dim"`` maps to ``Info.WIDTH``, ``"c1"`` to ``Info.X1`` and ``"c2"`` to
    ``Info.X2``.
    """

    def __init__(self, box: dict):
        """Wrap ``box`` with the horizontal alias table."""
        horizontal_keys = {
            "dim": Info.WIDTH,
            "c1": Info.X1,
            "c2": Info.X2,
        }
        super().__init__(box, horizontal_keys)
class VerticalKeyMapper(SplitKeyMapper):
    """Key mapper whose aliases resolve to the box's height and y-coordinates.

    ``"dim"`` maps to ``Info.HEIGHT``, ``"c1"`` to ``Info.Y1`` and ``"c2"`` to
    ``Info.Y2``.
    """

    def __init__(self, box: dict):
        """Wrap ``box`` with the vertical alias table."""
        vertical_keys = {
            "dim": Info.HEIGHT,
            "c1": Info.Y1,
            "c2": Info.Y2,
        }
        super().__init__(box, vertical_keys)
class BoxSplitter:
    """Randomly partition a box into smaller boxes by repeated binary splits.

    Every level picks the horizontal or the vertical axis at random and cuts
    the box at a random position along that axis. Boxes are wrapped in
    ``HorizontalKeyMapper`` / ``VerticalKeyMapper``, so they must be dicts
    that carry the keys those mappers translate to.
    """

    # Smallest extent a child box may have along the split axis. A box is only
    # split when both children can reach this size, i.e. when its own extent
    # is at least twice this value.
    _MIN_CHILD_DIM = 5

    def __init__(self):
        # Depth limit of the current run; set by split_box().
        self.__steps = None

    def split_box(self, box, steps=5):
        """Split ``box`` recursively, at most ``steps`` levels deep.

        Args:
            box: The box to split. It is never modified; child boxes are
                deep copies with adjusted extent and coordinates.
            steps: Maximum number of split levels, yielding at most
                ``2 ** steps`` boxes. With ``steps <= 0`` nothing is split.

        Returns:
            An iterable of boxes: the list ``[box]`` when no level is left,
            otherwise a lazy iterator over the resulting leaf boxes. A box
            that is too small along the chosen axis is passed on unsplit.

        Note:
            The depth limit is stored on the instance and the result is
            evaluated lazily, so consume the result before calling
            ``split_box`` on the same instance again.
        """
        self.__steps = steps
        return self.__split_recursively(box, 0)

    def __split_recursively(self, box, step):
        """Split ``box`` once more if levels remain, otherwise return it as a leaf."""
        if self.__steps_left(step):
            return self.__split_and_recurse(box, step)
        return self.__base_case(box)

    def __steps_left(self, step):
        """Return whether ``step`` is still below the configured depth limit."""
        return step < self.__steps

    @staticmethod
    def __base_case(box):
        """Wrap an unsplit ``box`` in a single-element list."""
        return [box]

    def __split_and_recurse(self, box, step):
        """Split ``box`` at level ``step`` and flatten the leaves of all children."""
        new_boxes = self.__random_split(box)
        # Pass the current level on unchanged: __tree_recurse is the single
        # place where the depth is incremented. Incrementing here as well
        # would consume two levels per split and halve the effective depth.
        new_boxes_per_branch = self.__tree_recurse(new_boxes, step)
        return chain.from_iterable(new_boxes_per_branch)

    def __random_split(self, box):
        """Try to split ``box`` along a randomly chosen axis."""
        splitter = random.choice([self.__split_horizontal, self.__split_vertical])
        return splitter(box)

    def __tree_recurse(self, boxes, step):
        """Lazily continue splitting each of ``boxes`` one level below ``step``."""
        return map(rpartial(self.__split_recursively, step + 1), boxes)

    def __split_horizontal(self, box):
        """Split ``box`` along the axis described by ``HorizontalKeyMapper``."""
        return self.__split_if_large_enough(HorizontalKeyMapper(box))

    def __split_vertical(self, box):
        """Split ``box`` along the axis described by ``VerticalKeyMapper``."""
        return self.__split_if_large_enough(VerticalKeyMapper(box))

    def __split_if_large_enough(self, sabox: "SplitKeyMapper"):
        """Return two child boxes, or the wrapped box unsplit if it is too small."""
        if self.__large_enough(sabox):
            return self.__get_child_boxes(sabox)
        return self.__base_case(sabox.box)

    @staticmethod
    def __large_enough(box):
        """Return whether both children of ``box`` can reach the minimum extent."""
        return box["dim"] >= 2 * BoxSplitter._MIN_CHILD_DIM

    @staticmethod
    def __get_child_boxes(sabox: "SplitKeyMapper"):
        """Cut the wrapped box at a random position and return both halves.

        The cut keeps at least ``_MIN_CHILD_DIM`` on either side. The wrapped
        box itself is left untouched; both halves are deep copies.
        """
        min_dim = BoxSplitter._MIN_CHILD_DIM
        split_len = random.randint(min_dim, sabox["dim"] - min_dim)
        split_point = sabox["c1"] + split_len

        box_left, box_right = juxt(deepcopy, deepcopy)(sabox)

        # The first half keeps c1 and ends at the cut; the second half starts
        # at the cut and keeps c2.
        box_left["dim"] = split_len
        box_right["dim"] = sabox["dim"] - split_len
        box_left["c2"] = split_point
        box_right["c1"] = split_point

        return box_left.box, box_right.box