Refactoring
Further clean up Rectangle class
This commit is contained in:
parent
8c6b940364
commit
4be91de036
@ -2,44 +2,58 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from json import dumps
|
||||
from typing import Iterable
|
||||
from operator import itemgetter
|
||||
from typing import Iterable, Union, Dict
|
||||
|
||||
import numpy as np
|
||||
from funcy import identity
|
||||
|
||||
from cv_analysis.utils.spacial import adjacent, contains, intersection, iou, area, is_contained
|
||||
|
||||
Coord = Union[int, float]
|
||||
|
||||
|
||||
class Rectangle:
|
||||
def __init__(
|
||||
self,
|
||||
x1=None,
|
||||
y1=None,
|
||||
w=None,
|
||||
h=None,
|
||||
x2=None,
|
||||
y2=None,
|
||||
discrete=True,
|
||||
):
|
||||
def __init__(self, x1, y1, x2, y2, discrete=True):
|
||||
"""Creates a rectangle from two points."""
|
||||
nearest_valid = int if discrete else identity
|
||||
|
||||
try:
|
||||
self.x1 = nearest_valid(x1)
|
||||
self.y1 = nearest_valid(y1)
|
||||
self.w = nearest_valid(w) if w else nearest_valid(x2 - x1)
|
||||
self.h = nearest_valid(h) if h else nearest_valid(y2 - y1)
|
||||
self.x2 = nearest_valid(x2) if x2 else self.x1 + self.w
|
||||
self.y2 = nearest_valid(y2) if y2 else self.y1 + self.h
|
||||
assert np.isclose(self.x1 + self.w, self.x2)
|
||||
assert np.isclose(self.y1 + self.h, self.y2)
|
||||
except Exception as err:
|
||||
raise ValueError("x1, y1, (w|x2), and (h|y2) must be defined.") from err
|
||||
self.__x1 = nearest_valid(x1)
|
||||
self.__y1 = nearest_valid(y1)
|
||||
self.__x2 = nearest_valid(x2)
|
||||
self.__y2 = nearest_valid(y2)
|
||||
|
||||
self.__w = nearest_valid(x2 - x1)
|
||||
self.__h = nearest_valid(y2 - y1)
|
||||
|
||||
@property
|
||||
def x1(self):
|
||||
return self.__x1
|
||||
|
||||
@property
|
||||
def x2(self):
|
||||
return self.__x2
|
||||
|
||||
@property
|
||||
def y1(self):
|
||||
return self.__y1
|
||||
|
||||
@property
|
||||
def y2(self):
|
||||
return self.__y2
|
||||
|
||||
@property
|
||||
def w(self):
|
||||
return self.__w
|
||||
|
||||
@property
|
||||
def h(self):
|
||||
return self.__h
|
||||
|
||||
def __str__(self):
|
||||
return dumps(self.json())
|
||||
|
||||
def __repr__(self):
|
||||
return str(self.json())
|
||||
return str(self)
|
||||
|
||||
def __iter__(self):
|
||||
return list(self.json().values()).__iter__()
|
||||
@ -51,18 +65,25 @@ class Rectangle:
|
||||
return hash((self.x1, self.y1, self.x2, self.y2))
|
||||
|
||||
@classmethod
|
||||
def from_xyxy(cls, xyxy_tuple, discrete=True):
|
||||
x1, y1, x2, y2 = xyxy_tuple
|
||||
return cls(x1=x1, y1=y1, x2=x2, y2=y2, discrete=discrete)
|
||||
def from_xyxy(cls, xyxy: Iterable[Coord], discrete=True):
|
||||
"""Creates a rectangle from two points."""
|
||||
return cls(*xyxy, discrete=discrete)
|
||||
|
||||
@classmethod
|
||||
def from_xywh(cls, xywh_tuple, discrete=True):
|
||||
x, y, w, h = xywh_tuple
|
||||
return cls(x1=x, y1=y, w=w, h=h, discrete=discrete)
|
||||
def from_xywh(cls, xywh: Iterable[Coord], discrete=True):
|
||||
"""Creates a rectangle from a point and a width and height."""
|
||||
x1, y1, w, h = xywh
|
||||
x2 = x1 + w
|
||||
y2 = y1 + h
|
||||
return cls(x1, y1, x2, y2, discrete=discrete)
|
||||
|
||||
@classmethod
|
||||
def from_dict_xywh(cls, xywh_dict, discrete=True):
|
||||
return cls(x1=xywh_dict["x"], y1=xywh_dict["y"], w=xywh_dict["width"], h=xywh_dict["height"], discrete=discrete)
|
||||
def from_dict_xywh(cls, xywh: Dict[str, Coord], discrete=True):
|
||||
"""Creates a rectangle from a point and a width and height."""
|
||||
x1, y1, w, h = itemgetter("x", "y", "width", "height")(xywh)
|
||||
x2 = x1 + w
|
||||
y2 = y1 + h
|
||||
return cls(x1, y1, x2, y2, discrete=discrete)
|
||||
|
||||
def xyxy(self):
|
||||
return self.x1, self.y1, self.x2, self.y2
|
||||
@ -83,19 +104,25 @@ class Rectangle:
|
||||
return {"x": self.x1, "y": self.y1, "width": self.w, "height": self.h}
|
||||
|
||||
def intersection(self, other):
|
||||
"""Calculates the intersection of this and another rectangle."""
|
||||
return intersection(self, other)
|
||||
|
||||
def area(self):
|
||||
"""Calculates the area of this rectangle."""
|
||||
return area(self)
|
||||
|
||||
def iou(self, other: Rectangle):
|
||||
"""Calculates the intersection over union of this and another rectangle."""
|
||||
return iou(self, other)
|
||||
|
||||
def includes(self, other: Rectangle, tol=3):
|
||||
"""Checks if this rectangle contains another."""
|
||||
return contains(self, other, tol)
|
||||
|
||||
def is_included(self, rectangles: Iterable[Rectangle]):
|
||||
"""Checks if this rectangle is contained by any of the given rectangles."""
|
||||
return is_contained(self, rectangles)
|
||||
|
||||
def adjacent(self, other: Rectangle, tolerance=7):
|
||||
"""Checks if this rectangle is adjacent to another."""
|
||||
return adjacent(self, other, tolerance)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user