refactoring

This commit is contained in:
Matthias Bisping 2022-04-11 09:53:32 +02:00
parent 710783a2f8
commit 57440f5106
2 changed files with 17 additions and 19 deletions

View File

@ -10,7 +10,7 @@ from image_prediction.info import Info
from image_prediction.stitching.grouping import CoordGrouper from image_prediction.stitching.grouping import CoordGrouper
from image_prediction.stitching.utils import make_coord_getter, flatten_groups_once from image_prediction.stitching.utils import make_coord_getter, flatten_groups_once
from image_prediction.utils.generic import until from image_prediction.utils.generic import until
from test.utils.stitching import HorizontalKeyMapper, VerticalKeyMapper from test.utils.stitching import HorizontalSplitMapper, VerticalMapper
def no_new_merges(pairs1, pairs2): def no_new_merges(pairs1, pairs2):
@ -120,12 +120,12 @@ def merge_pair_horizontally(p1: ImageMetadataPair, p2: ImageMetadataPair):
def merge_metadata_vertically(m1: dict, m2: dict): def merge_metadata_vertically(m1: dict, m2: dict):
m1, m2 = map(VerticalKeyMapper, [m1, m2]) m1, m2 = map(VerticalMapper, [m1, m2])
return merge_metadata(m1, m2) return merge_metadata(m1, m2)
def merge_metadata_horizontally(m1: dict, m2: dict): def merge_metadata_horizontally(m1: dict, m2: dict):
m1, m2 = map(HorizontalKeyMapper, [m1, m2]) m1, m2 = map(HorizontalSplitMapper, [m1, m2])
return merge_metadata(m1, m2) return merge_metadata(m1, m2)

View File

@ -2,15 +2,13 @@ import abc
import random import random
from copy import deepcopy from copy import deepcopy
from itertools import chain from itertools import chain
from operator import itemgetter
from funcy import rpartial, juxt from funcy import rpartial, juxt
from image_prediction.info import Info from image_prediction.info import Info
class SplitKeyMapper(abc.ABC): class SplitMapper(abc.ABC):
def __init__(self, wrapped: dict, keymap: dict): def __init__(self, wrapped: dict, keymap: dict):
self.wrapped = wrapped self.wrapped = wrapped
self.keymap = keymap self.keymap = keymap
@ -23,35 +21,35 @@ class SplitKeyMapper(abc.ABC):
@property @property
def dim(self): def dim(self):
return self.__getitem__("dim") return self["dim"]
@dim.setter @dim.setter
def dim(self, value): def dim(self, value):
self.__setitem__("dim", value) self["dim"] = value
@property @property
def c1(self): def c1(self):
return self.__getitem__("c1") return self["c1"]
@c1.setter @c1.setter
def c1(self, value): def c1(self, value):
self.__setitem__("c1", value) self["c1"] = value
@property @property
def c2(self): def c2(self):
return self.__getitem__("c2") return self["c2"]
@c2.setter @c2.setter
def c2(self, value): def c2(self, value):
self.__setitem__("c2", value) self["c2"] = value
class HorizontalKeyMapper(SplitKeyMapper): class HorizontalSplitMapper(SplitMapper):
def __init__(self, wrapped: dict): def __init__(self, wrapped: dict):
super().__init__(wrapped, {"dim": Info.WIDTH, "c1": Info.X1, "c2": Info.X2}) super().__init__(wrapped, {"dim": Info.WIDTH, "c1": Info.X1, "c2": Info.X2})
class VerticalKeyMapper(SplitKeyMapper): class VerticalMapper(SplitMapper):
def __init__(self, wrapped: dict): def __init__(self, wrapped: dict):
super().__init__(wrapped, {"dim": Info.HEIGHT, "c1": Info.Y1, "c2": Info.Y2}) super().__init__(wrapped, {"dim": Info.HEIGHT, "c1": Info.Y1, "c2": Info.Y2})
@ -88,12 +86,12 @@ class BoxSplitter:
return map(rpartial(self.__split_recursively, step + 1), boxes) return map(rpartial(self.__split_recursively, step + 1), boxes)
def __split_horizontal(self, box): def __split_horizontal(self, box):
return self.__split_if_large_enough(HorizontalKeyMapper(box)) return self.__split_if_large_enough(HorizontalSplitMapper(box))
def __split_vertical(self, box): def __split_vertical(self, box):
return self.__split_if_large_enough(VerticalKeyMapper(box)) return self.__split_if_large_enough(VerticalMapper(box))
def __split_if_large_enough(self, wrapped_box: SplitKeyMapper): def __split_if_large_enough(self, wrapped_box: SplitMapper):
return ( return (
self.__get_child_boxes(wrapped_box) self.__get_child_boxes(wrapped_box)
if self.__large_enough(wrapped_box) if self.__large_enough(wrapped_box)
@ -101,11 +99,11 @@ class BoxSplitter:
) )
@staticmethod @staticmethod
def __large_enough(wrapped_box: SplitKeyMapper): def __large_enough(wrapped_box: SplitMapper):
return wrapped_box.dim >= 10 return wrapped_box.dim >= 10
@staticmethod @staticmethod
def __get_child_boxes(wrapped_box: SplitKeyMapper): def __get_child_boxes(wrapped_box: SplitMapper):
split_len = random.randint(5, wrapped_box.dim - 5) split_len = random.randint(5, wrapped_box.dim - 5)
split_point = wrapped_box.c1 + split_len split_point = wrapped_box.c1 + split_len