Matthias Bisping 8c6b940364 Refactoring
Clean up Rectangle class
2023-01-04 14:57:47 +01:00

102 lines
3.0 KiB
Python

# See https://stackoverflow.com/a/33533514
from __future__ import annotations
from json import dumps
from typing import Iterable
import numpy as np
from funcy import identity
from cv_analysis.utils.spacial import adjacent, contains, intersection, iou, area, is_contained
class Rectangle:
    """Axis-aligned rectangle stored as a corner, a size and the opposite corner.

    It can be built from ``(x1, y1, w, h)`` or from ``(x1, y1, x2, y2)``. The
    missing values are derived so that ``x2 == x1 + w`` and ``y2 == y1 + h``
    hold (exactly for ``discrete=True``, up to ``numpy.isclose`` tolerance
    otherwise).
    """

    def __init__(
        self,
        x1=None,
        y1=None,
        w=None,
        h=None,
        x2=None,
        y2=None,
        discrete=True,
    ):
        """Build a rectangle from a corner plus either its size or its opposite corner.

        Args:
            x1, y1: Coordinates of the first corner (required).
            w, h: Width / height. Each may be omitted when the matching
                ``x2`` / ``y2`` is given. ``0`` is a valid value.
            x2, y2: Coordinates of the opposite corner. Each may be omitted
                when the matching ``w`` / ``h`` is given.
            discrete: If true, every value is converted with ``int`` (which
                truncates toward zero, it does not round). Otherwise values are
                stored unchanged.

        Raises:
            ValueError: If a required value is missing or cannot be converted,
                or if both a size and an opposite corner are given and they
                disagree.
        """
        nearest_valid = int if discrete else identity
        try:
            self.x1, self.w, self.x2 = self._resolve_axis(x1, w, x2, nearest_valid)
            self.y1, self.h, self.y2 = self._resolve_axis(y1, h, y2, nearest_valid)
            consistent = np.isclose(self.x1 + self.w, self.x2) and np.isclose(self.y1 + self.h, self.y2)
        except (TypeError, ValueError, ArithmeticError) as err:
            raise ValueError("x1, y1, (w|x2), and (h|y2) must be defined.") from err
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if not consistent:
            raise ValueError("Inconsistent rectangle: x1 + w must equal x2 and y1 + h must equal y2.")

    @staticmethod
    def _resolve_axis(start, size, end, convert):
        """Return ``(start, size, end)`` for one axis, deriving whichever of size/end is missing."""
        # `is None` rather than truthiness, so that a size or corner of 0 counts as given.
        if start is None or (size is None and end is None):
            raise TypeError("missing coordinate")
        start = convert(start)
        if end is not None:
            end = convert(end)
        # Derive the size from the *converted* corners: with integer conversion
        # int(x2 - x1) may differ from int(x2) - int(x1) for fractional inputs.
        size = end - start if size is None else convert(size)
        if end is None:
            end = start + size
        return start, size, end

    def __str__(self):
        """Return the ``json()`` dict serialized as a JSON string."""
        return dumps(self.json())

    def __repr__(self):
        """Return the ``str()`` of the ``json()`` dict."""
        return str(self.json())

    def __iter__(self):
        """Iterate over the values of ``json()``, i.e. x, y, width, height."""
        return iter(self.json().values())

    def __eq__(self, other: object):
        """Compare by ``(x1, y1, w, h)``; non-rectangles are left to the other operand."""
        if not isinstance(other, Rectangle):
            return NotImplemented
        return self.xywh() == other.xywh()

    def __hash__(self):
        # Hash the same fields `__eq__` compares, so equal rectangles hash equally.
        return hash(self.xywh())

    @classmethod
    def from_xyxy(cls, xyxy_tuple, discrete=True):
        """Build a rectangle from an ``(x1, y1, x2, y2)`` sequence."""
        x1, y1, x2, y2 = xyxy_tuple
        return cls(x1=x1, y1=y1, x2=x2, y2=y2, discrete=discrete)

    @classmethod
    def from_xywh(cls, xywh_tuple, discrete=True):
        """Build a rectangle from an ``(x, y, width, height)`` sequence."""
        x, y, w, h = xywh_tuple
        return cls(x1=x, y1=y, w=w, h=h, discrete=discrete)

    @classmethod
    def from_dict_xywh(cls, xywh_dict, discrete=True):
        """Build a rectangle from a mapping with keys ``x``, ``y``, ``width``, ``height``."""
        return cls(x1=xywh_dict["x"], y1=xywh_dict["y"], w=xywh_dict["width"], h=xywh_dict["height"], discrete=discrete)

    def xyxy(self):
        """Return ``(x1, y1, x2, y2)``."""
        return self.x1, self.y1, self.x2, self.y2

    def xywh(self):
        """Return ``(x1, y1, w, h)``."""
        return self.x1, self.y1, self.w, self.h

    def json(self):
        """Return the default dict form, which is ``json_xywh()``."""
        return self.json_xywh()

    def json_full(self):
        """Return both corners plus ``width`` and ``height`` as a dict."""
        return {**self.json_xyxy(), "width": self.w, "height": self.h}

    def json_xyxy(self):
        """Return the corners as ``{"x1", "y1", "x2", "y2"}``."""
        return {"x1": self.x1, "y1": self.y1, "x2": self.x2, "y2": self.y2}

    def json_xywh(self):
        """Return corner and size as ``{"x", "y", "width", "height"}``."""
        return {"x": self.x1, "y": self.y1, "width": self.w, "height": self.h}

    def intersection(self, other):
        """Delegate to ``cv_analysis.utils.spacial.intersection(self, other)``."""
        return intersection(self, other)

    def area(self):
        """Delegate to ``cv_analysis.utils.spacial.area(self)``."""
        return area(self)

    def iou(self, other: Rectangle):
        """Delegate to ``cv_analysis.utils.spacial.iou(self, other)``."""
        return iou(self, other)

    def includes(self, other: Rectangle, tol=3):
        """Delegate to ``contains(self, other, tol)``.

        Presumably true when *other* lies inside this rectangle within *tol*
        units; confirm against ``cv_analysis.utils.spacial.contains``.
        """
        return contains(self, other, tol)

    def is_included(self, rectangles: Iterable[Rectangle]):
        """Delegate to ``is_contained(self, rectangles)``.

        Presumably true when this rectangle lies inside one of *rectangles*;
        confirm against ``cv_analysis.utils.spacial.is_contained``.
        """
        return is_contained(self, rectangles)

    def adjacent(self, other: Rectangle, tolerance=7):
        """Delegate to ``adjacent(self, other, tolerance)`` from ``cv_analysis.utils.spacial``."""
        return adjacent(self, other, tolerance)