From f4c05474056a4708a91640d8dde38b2e14f0f3d2 Mon Sep 17 00:00:00 2001 From: Matthias Bisping Date: Mon, 11 Apr 2022 12:16:42 +0200 Subject: [PATCH] refactoring: replaced split mapper with dataclass --- image_prediction/stitching/merging.py | 4 +- image_prediction/stitching/split_mapper.py | 60 ++++++++++------------ test/utils/stitching.py | 4 +- 3 files changed, 30 insertions(+), 38 deletions(-) diff --git a/image_prediction/stitching/merging.py b/image_prediction/stitching/merging.py index 87bc673..5ec9d27 100644 --- a/image_prediction/stitching/merging.py +++ b/image_prediction/stitching/merging.py @@ -10,7 +10,7 @@ from image_prediction.info import Info from image_prediction.stitching.grouping import CoordGrouper from image_prediction.stitching.utils import make_coord_getter, flatten_groups_once from image_prediction.utils.generic import until -from image_prediction.stitching.split_mapper import HorizontalSplitMapper, VerticalMapper +from image_prediction.stitching.split_mapper import HorizontalSplitMapper, VerticalSplitMapper def no_new_merges(pairs1, pairs2): @@ -120,7 +120,7 @@ def merge_pair_horizontally(p1: ImageMetadataPair, p2: ImageMetadataPair): def merge_metadata_vertically(m1: dict, m2: dict): - m1, m2 = map(VerticalMapper, [m1, m2]) + m1, m2 = map(VerticalSplitMapper, [m1, m2]) return merge_metadata(m1, m2) diff --git a/image_prediction/stitching/split_mapper.py b/image_prediction/stitching/split_mapper.py index 95b1470..da93871 100644 --- a/image_prediction/stitching/split_mapper.py +++ b/image_prediction/stitching/split_mapper.py @@ -1,49 +1,41 @@ -import abc +from copy import deepcopy +from dataclasses import field, dataclass +from operator import attrgetter from image_prediction.info import Info -class SplitMapper(abc.ABC): - def __init__(self, wrapped: dict, keymap: dict): - self.wrapped = wrapped - self.keymap = keymap +@dataclass +class SplitMapper: + """Manages access into a coordinate encoding mapping by abstracting over x1, x2 and y1, y2 as c1, c2; as well as + over width and height as 'dim'.""" + __keymap: dict + wrapped: dict + __wrapped: dict = field(init=False) + dim: float = field(init=False) + c1: float = field(init=False) + c2: float = field(init=False) - def __getitem__(self, key): - return self.wrapped[self.keymap[key]] - - def __setitem__(self, key, value): - self.wrapped[self.keymap[key]] = value + def __post_init__(self): + for k, v in self.__keymap.items(): + setattr(self, k, self.__wrapped[v]) @property - def dim(self): - return self["dim"] + def wrapped(self): + ret = deepcopy(self.__wrapped) + ret.update(dict(zip(self.__keymap.values(), attrgetter(*self.__keymap.keys())(self)))) + return ret - @dim.setter - def dim(self, value): - self["dim"] = value - - @property - def c1(self): - return self["c1"] - - @c1.setter - def c1(self, value): - self["c1"] = value - - @property - def c2(self): - return self["c2"] - - @c2.setter - def c2(self, value): - self["c2"] = value + @wrapped.setter + def wrapped(self, wrapped): + self.__wrapped = wrapped class HorizontalSplitMapper(SplitMapper): def __init__(self, wrapped: dict): - super().__init__(wrapped, {"dim": Info.WIDTH, "c1": Info.X1, "c2": Info.X2}) + super().__init__({"dim": Info.WIDTH, "c1": Info.X1, "c2": Info.X2}, wrapped) -class VerticalMapper(SplitMapper): +class VerticalSplitMapper(SplitMapper): def __init__(self, wrapped: dict): - super().__init__(wrapped, {"dim": Info.HEIGHT, "c1": Info.Y1, "c2": Info.Y2}) \ No newline at end of file + super().__init__({"dim": Info.HEIGHT, "c1": Info.Y1, "c2": Info.Y2}, wrapped) diff --git a/test/utils/stitching.py b/test/utils/stitching.py index eb41797..dc4cc0d 100644 --- a/test/utils/stitching.py +++ b/test/utils/stitching.py @@ -4,7 +4,7 @@ from itertools import chain from funcy import rpartial, juxt -from image_prediction.stitching.split_mapper import SplitMapper, HorizontalSplitMapper, VerticalMapper +from image_prediction.stitching.split_mapper import SplitMapper, HorizontalSplitMapper, VerticalSplitMapper class BoxSplitter: @@ -42,7 +42,7 @@ class BoxSplitter: return self.__split_if_large_enough(HorizontalSplitMapper(box)) def __split_vertical(self, box): - return self.__split_if_large_enough(VerticalMapper(box)) + return self.__split_if_large_enough(VerticalSplitMapper(box)) def __split_if_large_enough(self, wrapped_box: SplitMapper): return (