refactoring: replaced split mapper with dataclass

This commit is contained in:
Matthias Bisping 2022-04-11 12:16:42 +02:00
parent 1bea5fb9a8
commit f4c0547405
3 changed files with 30 additions and 38 deletions

View File

@ -10,7 +10,7 @@ from image_prediction.info import Info
from image_prediction.stitching.grouping import CoordGrouper from image_prediction.stitching.grouping import CoordGrouper
from image_prediction.stitching.utils import make_coord_getter, flatten_groups_once from image_prediction.stitching.utils import make_coord_getter, flatten_groups_once
from image_prediction.utils.generic import until from image_prediction.utils.generic import until
from image_prediction.stitching.split_mapper import HorizontalSplitMapper, VerticalMapper from image_prediction.stitching.split_mapper import HorizontalSplitMapper, VerticalSplitMapper
def no_new_merges(pairs1, pairs2): def no_new_merges(pairs1, pairs2):
@ -120,7 +120,7 @@ def merge_pair_horizontally(p1: ImageMetadataPair, p2: ImageMetadataPair):
def merge_metadata_vertically(m1: dict, m2: dict): def merge_metadata_vertically(m1: dict, m2: dict):
m1, m2 = map(VerticalMapper, [m1, m2]) m1, m2 = map(VerticalSplitMapper, [m1, m2])
return merge_metadata(m1, m2) return merge_metadata(m1, m2)

View File

@ -1,49 +1,41 @@
import abc from copy import deepcopy
from dataclasses import field, dataclass
from operator import attrgetter
from image_prediction.info import Info from image_prediction.info import Info
class SplitMapper(abc.ABC): @dataclass
def __init__(self, wrapped: dict, keymap: dict): class SplitMapper:
self.wrapped = wrapped """Manages access into a coordinate encoding mapping by abstracting over x1, x2 and y1, y2 as c1, c2; as well as
self.keymap = keymap over width and height as 'dim'."""
__keymap: dict
wrapped: dict
__wrapped: dict = field(init=False)
dim: float = field(init=False)
c1: float = field(init=False)
c2: float = field(init=False)
def __getitem__(self, key): def __post_init__(self):
return self.wrapped[self.keymap[key]] for k, v in self.__keymap.items():
setattr(self, k, self.__wrapped[v])
def __setitem__(self, key, value):
self.wrapped[self.keymap[key]] = value
@property @property
def dim(self): def wrapped(self):
return self["dim"] ret = deepcopy(self.__wrapped)
ret.update(dict(zip(self.__keymap.values(), attrgetter(*self.__keymap.keys())(self))))
return ret
@dim.setter @wrapped.setter
def dim(self, value): def wrapped(self, wrapped):
self["dim"] = value self.__wrapped = wrapped
@property
def c1(self):
return self["c1"]
@c1.setter
def c1(self, value):
self["c1"] = value
@property
def c2(self):
return self["c2"]
@c2.setter
def c2(self, value):
self["c2"] = value
class HorizontalSplitMapper(SplitMapper): class HorizontalSplitMapper(SplitMapper):
def __init__(self, wrapped: dict): def __init__(self, wrapped: dict):
super().__init__(wrapped, {"dim": Info.WIDTH, "c1": Info.X1, "c2": Info.X2}) super().__init__({"dim": Info.WIDTH, "c1": Info.X1, "c2": Info.X2}, wrapped)
class VerticalMapper(SplitMapper): class VerticalSplitMapper(SplitMapper):
def __init__(self, wrapped: dict): def __init__(self, wrapped: dict):
super().__init__(wrapped, {"dim": Info.HEIGHT, "c1": Info.Y1, "c2": Info.Y2}) super().__init__({"dim": Info.HEIGHT, "c1": Info.Y1, "c2": Info.Y2}, wrapped)

View File

@ -4,7 +4,7 @@ from itertools import chain
from funcy import rpartial, juxt from funcy import rpartial, juxt
from image_prediction.stitching.split_mapper import SplitMapper, HorizontalSplitMapper, VerticalMapper from image_prediction.stitching.split_mapper import SplitMapper, HorizontalSplitMapper, VerticalSplitMapper
class BoxSplitter: class BoxSplitter:
@ -42,7 +42,7 @@ class BoxSplitter:
return self.__split_if_large_enough(HorizontalSplitMapper(box)) return self.__split_if_large_enough(HorizontalSplitMapper(box))
def __split_vertical(self, box): def __split_vertical(self, box):
return self.__split_if_large_enough(VerticalMapper(box)) return self.__split_if_large_enough(VerticalSplitMapper(box))
def __split_if_large_enough(self, wrapped_box: SplitMapper): def __split_if_large_enough(self, wrapped_box: SplitMapper):
return ( return (