Refactoring

Further clean up Rectangle class
This commit is contained in:
Matthias Bisping 2023-01-04 15:20:07 +01:00
parent 8c6b940364
commit 4be91de036

View File

@ -2,44 +2,58 @@
from __future__ import annotations
from json import dumps
from typing import Iterable
from operator import itemgetter
from typing import Iterable, Union, Dict
import numpy as np
from funcy import identity
from cv_analysis.utils.spacial import adjacent, contains, intersection, iou, area, is_contained
Coord = Union[int, float]
class Rectangle:
def __init__(
self,
x1=None,
y1=None,
w=None,
h=None,
x2=None,
y2=None,
discrete=True,
):
def __init__(self, x1, y1, x2, y2, discrete=True):
"""Creates a rectangle from two points."""
nearest_valid = int if discrete else identity
try:
self.x1 = nearest_valid(x1)
self.y1 = nearest_valid(y1)
self.w = nearest_valid(w) if w else nearest_valid(x2 - x1)
self.h = nearest_valid(h) if h else nearest_valid(y2 - y1)
self.x2 = nearest_valid(x2) if x2 else self.x1 + self.w
self.y2 = nearest_valid(y2) if y2 else self.y1 + self.h
assert np.isclose(self.x1 + self.w, self.x2)
assert np.isclose(self.y1 + self.h, self.y2)
except Exception as err:
raise ValueError("x1, y1, (w|x2), and (h|y2) must be defined.") from err
self.__x1 = nearest_valid(x1)
self.__y1 = nearest_valid(y1)
self.__x2 = nearest_valid(x2)
self.__y2 = nearest_valid(y2)
self.__w = nearest_valid(x2 - x1)
self.__h = nearest_valid(y2 - y1)
@property
def x1(self):
return self.__x1
@property
def x2(self):
return self.__x2
@property
def y1(self):
return self.__y1
@property
def y2(self):
return self.__y2
@property
def w(self):
return self.__w
@property
def h(self):
return self.__h
def __str__(self):
return dumps(self.json())
def __repr__(self):
return str(self.json())
return str(self)
def __iter__(self):
return list(self.json().values()).__iter__()
@ -51,18 +65,25 @@ class Rectangle:
return hash((self.x1, self.y1, self.x2, self.y2))
@classmethod
def from_xyxy(cls, xyxy_tuple, discrete=True):
x1, y1, x2, y2 = xyxy_tuple
return cls(x1=x1, y1=y1, x2=x2, y2=y2, discrete=discrete)
def from_xyxy(cls, xyxy: Iterable[Coord], discrete=True):
"""Creates a rectangle from two points."""
return cls(*xyxy, discrete=discrete)
@classmethod
def from_xywh(cls, xywh_tuple, discrete=True):
x, y, w, h = xywh_tuple
return cls(x1=x, y1=y, w=w, h=h, discrete=discrete)
def from_xywh(cls, xywh: Iterable[Coord], discrete=True):
"""Creates a rectangle from a point and a width and height."""
x1, y1, w, h = xywh
x2 = x1 + w
y2 = y1 + h
return cls(x1, y1, x2, y2, discrete=discrete)
@classmethod
def from_dict_xywh(cls, xywh_dict, discrete=True):
return cls(x1=xywh_dict["x"], y1=xywh_dict["y"], w=xywh_dict["width"], h=xywh_dict["height"], discrete=discrete)
def from_dict_xywh(cls, xywh: Dict[str, Coord], discrete=True):
"""Creates a rectangle from a point and a width and height."""
x1, y1, w, h = itemgetter("x", "y", "width", "height")(xywh)
x2 = x1 + w
y2 = y1 + h
return cls(x1, y1, x2, y2, discrete=discrete)
def xyxy(self):
return self.x1, self.y1, self.x2, self.y2
@ -83,19 +104,25 @@ class Rectangle:
return {"x": self.x1, "y": self.y1, "width": self.w, "height": self.h}
def intersection(self, other):
"""Calculates the intersection of this and another rectangle."""
return intersection(self, other)
def area(self):
"""Calculates the area of this rectangle."""
return area(self)
def iou(self, other: Rectangle):
"""Calculates the intersection over union of this and another rectangle."""
return iou(self, other)
def includes(self, other: Rectangle, tol=3):
"""Checks if this rectangle contains another."""
return contains(self, other, tol)
def is_included(self, rectangles: Iterable[Rectangle]):
"""Checks if this rectangle is contained by any of the given rectangles."""
return is_contained(self, rectangles)
def adjacent(self, other: Rectangle, tolerance=7):
"""Checks if this rectangle is adjacent to another."""
return adjacent(self, other, tolerance)