From 4be91de03644c9e54c28865b590da0d47c20ad12 Mon Sep 17 00:00:00 2001 From: Matthias Bisping Date: Wed, 4 Jan 2023 15:20:07 +0100 Subject: [PATCH] Refactoring Further clean up Rectangle class --- cv_analysis/utils/structures.py | 91 +++++++++++++++++++++------------ 1 file changed, 59 insertions(+), 32 deletions(-) diff --git a/cv_analysis/utils/structures.py b/cv_analysis/utils/structures.py index e8ae2c2..c799df6 100644 --- a/cv_analysis/utils/structures.py +++ b/cv_analysis/utils/structures.py @@ -2,44 +2,58 @@ from __future__ import annotations from json import dumps -from typing import Iterable +from operator import itemgetter +from typing import Iterable, Union, Dict -import numpy as np from funcy import identity from cv_analysis.utils.spacial import adjacent, contains, intersection, iou, area, is_contained +Coord = Union[int, float] + class Rectangle: - def __init__( - self, - x1=None, - y1=None, - w=None, - h=None, - x2=None, - y2=None, - discrete=True, - ): + def __init__(self, x1, y1, x2, y2, discrete=True): + """Creates a rectangle from two points.""" nearest_valid = int if discrete else identity - try: - self.x1 = nearest_valid(x1) - self.y1 = nearest_valid(y1) - self.w = nearest_valid(w) if w else nearest_valid(x2 - x1) - self.h = nearest_valid(h) if h else nearest_valid(y2 - y1) - self.x2 = nearest_valid(x2) if x2 else self.x1 + self.w - self.y2 = nearest_valid(y2) if y2 else self.y1 + self.h - assert np.isclose(self.x1 + self.w, self.x2) - assert np.isclose(self.y1 + self.h, self.y2) - except Exception as err: - raise ValueError("x1, y1, (w|x2), and (h|y2) must be defined.") from err + self.__x1 = nearest_valid(x1) + self.__y1 = nearest_valid(y1) + self.__x2 = nearest_valid(x2) + self.__y2 = nearest_valid(y2) + + self.__w = nearest_valid(x2 - x1) + self.__h = nearest_valid(y2 - y1) + + @property + def x1(self): + return self.__x1 + + @property + def x2(self): + return self.__x2 + + @property + def y1(self): + return self.__y1 + + @property + def y2(self): + return self.__y2 + + @property + def w(self): + return self.__w + + @property + def h(self): + return self.__h def __str__(self): return dumps(self.json()) def __repr__(self): - return str(self.json()) + return str(self) def __iter__(self): return list(self.json().values()).__iter__() @@ -51,18 +65,25 @@ class Rectangle: return hash((self.x1, self.y1, self.x2, self.y2)) @classmethod - def from_xyxy(cls, xyxy_tuple, discrete=True): - x1, y1, x2, y2 = xyxy_tuple - return cls(x1=x1, y1=y1, x2=x2, y2=y2, discrete=discrete) + def from_xyxy(cls, xyxy: Iterable[Coord], discrete=True): + """Creates a rectangle from two points.""" + return cls(*xyxy, discrete=discrete) @classmethod - def from_xywh(cls, xywh_tuple, discrete=True): - x, y, w, h = xywh_tuple - return cls(x1=x, y1=y, w=w, h=h, discrete=discrete) + def from_xywh(cls, xywh: Iterable[Coord], discrete=True): + """Creates a rectangle from a point and a width and height.""" + x1, y1, w, h = xywh + x2 = x1 + w + y2 = y1 + h + return cls(x1, y1, x2, y2, discrete=discrete) @classmethod - def from_dict_xywh(cls, xywh_dict, discrete=True): - return cls(x1=xywh_dict["x"], y1=xywh_dict["y"], w=xywh_dict["width"], h=xywh_dict["height"], discrete=discrete) + def from_dict_xywh(cls, xywh: Dict[str, Coord], discrete=True): + """Creates a rectangle from a point and a width and height.""" + x1, y1, w, h = itemgetter("x", "y", "width", "height")(xywh) + x2 = x1 + w + y2 = y1 + h + return cls(x1, y1, x2, y2, discrete=discrete) def xyxy(self): return self.x1, self.y1, self.x2, self.y2 @@ -83,19 +104,25 @@ class Rectangle: return {"x": self.x1, "y": self.y1, "width": self.w, "height": self.h} def intersection(self, other): + """Calculates the intersection of this and another rectangle.""" return intersection(self, other) def area(self): + """Calculates the area of this rectangle.""" return area(self) def iou(self, other: Rectangle): + """Calculates the intersection over union of this and another rectangle.""" return iou(self, other) def includes(self, other: Rectangle, tol=3): + """Checks if this rectangle contains another.""" return contains(self, other, tol) def is_included(self, rectangles: Iterable[Rectangle]): + """Checks if this rectangle is contained by any of the given rectangles.""" return is_contained(self, rectangles) def adjacent(self, other: Rectangle, tolerance=7): + """Checks if this rectangle is adjacent to another.""" return adjacent(self, other, tolerance)