from json import dumps
from typing import Iterable

import numpy as np
from funcy import identity, juxt


def _resolve_span(start, size, end, make_discrete):
    """Resolve one axis of a rectangle from its start plus its size and/or its end.

    Args:
        start: First coordinate along the axis; required.
        size: Extent along the axis, or ``None`` to derive it from ``end``.
        end: Second coordinate along the axis, or ``None`` to derive it from ``size``.
        make_discrete: Conversion applied to each supplied value (``int`` or ``identity``).

    Returns:
        ``(start, size, end)`` after conversion. When both ``size`` and ``end`` are supplied they
        are returned as converted; their consistency is checked by the caller.

    Raises:
        ValueError: if ``start`` is missing, or both ``size`` and ``end`` are missing.
        TypeError, ValueError, OverflowError: propagated from ``make_discrete`` or from arithmetic
            on unsuitable values.
    """
    # Compare against None explicitly: 0 is a legitimate start, size or end.
    if start is None or (size is None and end is None):
        raise ValueError("start and one of size/end are required")
    start = make_discrete(start)
    if size is None:
        end = make_discrete(end)
        # Derive the size from the already-converted endpoints so start + size == end holds
        # exactly; int(end - start) can differ from int(end) - int(start) for float input.
        size = end - start
    elif end is None:
        size = make_discrete(size)
        end = start + size
    else:
        size = make_discrete(size)
        end = make_discrete(end)
    return start, size, end


class Rectangle:
    """Axis-aligned rectangle described by a first corner, a size, and a second corner.

    Attributes:
        x1, y1: First corner coordinates.
        w, h: Width and height; ``x1 + w ≈ x2`` and ``y1 + h ≈ y2`` are checked (``np.isclose``)
            at construction.
        x2, y2: Second corner coordinates.
        indent: Indentation passed to ``json.dumps`` by ``__str__``.
        format: Layout produced by ``json()``: ``"xywh"``, ``"xyxy"``, or any other value for the
            full layout of ``json_full()``.
    """

    # Defining __eq__ already makes instances unhashable; stated explicitly for clarity.
    __hash__ = None

    def __init__(self, x1=None, y1=None, w=None, h=None, x2=None, y2=None, indent=4, format="xywh", discrete=True):
        """Build a rectangle from ``x1``/``y1`` plus, per axis, a size and/or a second coordinate.

        Args:
            x1, y1: First corner; required.
            w, h: Width/height; may be omitted when ``x2``/``y2`` is given.
            x2, y2: Second corner; may be omitted when ``w``/``h`` is given.
            indent: JSON indentation used by ``__str__``.
            format: Layout used by ``json()`` (``"xywh"``, ``"xyxy"``, or full).
            discrete: When true, every supplied value is converted with ``int`` (truncation).

        Raises:
            ValueError: if a required value is missing or cannot be converted, or if a supplied
                size and second coordinate disagree.
        """
        make_discrete = int if discrete else identity
        try:
            self.x1, self.w, self.x2 = _resolve_span(x1, w, x2, make_discrete)
            self.y1, self.h, self.y2 = _resolve_span(y1, h, y2, make_discrete)
        except (TypeError, ValueError, OverflowError) as err:
            raise ValueError("x1, y1, (w|x2), and (h|y2) must be defined.") from err
        # Explicit check rather than `assert`, which is stripped when Python runs with -O.
        if not (np.isclose(self.x1 + self.w, self.x2) and np.isclose(self.y1 + self.h, self.y2)):
            raise ValueError(
                "Inconsistent rectangle: expected x1 + w == x2 and y1 + h == y2, got "
                f"x1={self.x1}, w={self.w}, x2={self.x2}, y1={self.y1}, h={self.h}, y2={self.y2}."
            )
        self.indent = indent
        self.format = format

    def json_xywh(self):
        """Return the rectangle as ``{"x", "y", "width", "height"}``."""
        return {"x": self.x1, "y": self.y1, "width": self.w, "height": self.h}

    def json_xyxy(self):
        """Return the rectangle as ``{"x1", "y1", "x2", "y2"}``."""
        return {"x1": self.x1, "y1": self.y1, "x2": self.x2, "y2": self.y2}

    def json_full(self):
        """Return both corners (keyed ``x0``/``y0`` and ``x1``/``y1``) plus width and height."""
        # TODO: can we make all coords x0, y0 based? :)
        return {
            "x0": self.x1,
            "y0": self.y1,
            "x1": self.x2,
            "y1": self.y2,
            "width": self.w,
            "height": self.h,
        }

    def json(self):
        """Return the dict layout selected by ``self.format``; unknown formats use ``json_full``."""
        json_func = {"xywh": self.json_xywh, "xyxy": self.json_xyxy}.get(self.format, self.json_full)
        return json_func()

    def xyxy(self):
        """Return ``(x1, y1, x2, y2)``."""
        return self.x1, self.y1, self.x2, self.y2

    def xywh(self):
        """Return ``(x1, y1, w, h)``."""
        return self.x1, self.y1, self.w, self.h

    def intersection(self, rect):
        """Return the area shared with ``rect``, or 0 when the rectangles are disjoint.

        Assumes both rectangles have ``x1 <= x2`` and ``y1 <= y2``.
        """
        bx1, by1, bx2, by2 = rect.xyxy()
        if (self.x1 > bx2) or (bx1 > self.x2) or (self.y1 > by2) or (by1 > self.y2):
            return 0
        intersection_ = (min(self.x2, bx2) - max(self.x1, bx1)) * (min(self.y2, by2) - max(self.y1, by1))
        return intersection_

    def area(self):
        """Return ``(x2 - x1) * (y2 - y1)``."""
        return (self.x2 - self.x1) * (self.y2 - self.y1)

    def iou(self, rect):
        """Return intersection-over-union with ``rect``; 0 when there is no overlap."""
        intersection = self.intersection(rect)
        if intersection == 0:
            return 0
        union = self.area() + rect.area() - intersection
        return intersection / union

    def includes(self, other: "Rectangle", tol=3):
        """Return True if ``other`` lies inside this rectangle.

        Each edge of ``other`` may extend up to ``tol`` beyond the corresponding edge of ``self``.
        """
        return (
            other.x1 + tol >= self.x1
            and other.y1 + tol >= self.y1
            and other.x2 - tol <= self.x2
            and other.y2 - tol <= self.y2
        )

    def is_included(self, rectangles: Iterable["Rectangle"]):
        """Return True if any rectangle in ``rectangles`` that is not equal to ``self`` includes it."""
        return any(rect.includes(self) for rect in rectangles if rect != self)

    def adjacent(self, rect2: "Rectangle", tolerance=7):
        """Return True if ``rect2`` is adjacent to this rectangle; False when ``rect2`` is None."""
        if rect2 is None:
            return False
        # Resolves to the module-level `adjacent` function, not this method.
        return adjacent(self, rect2, tolerance)

    @classmethod
    def from_xyxy(cls, xyxy_tuple, discrete=True):
        """Build from an ``(x1, y1, x2, y2)`` sequence."""
        x1, y1, x2, y2 = xyxy_tuple
        return cls(x1=x1, y1=y1, x2=x2, y2=y2, discrete=discrete)

    @classmethod
    def from_xywh(cls, xywh_tuple, discrete=True):
        """Build from an ``(x, y, w, h)`` sequence."""
        x, y, w, h = xywh_tuple
        return cls(x1=x, y1=y, w=w, h=h, discrete=discrete)

    @classmethod
    def from_dict_xywh(cls, xywh_dict, discrete=True):
        """Build from a mapping with ``"x"``, ``"y"``, ``"width"`` and ``"height"`` keys."""
        return cls(x1=xywh_dict["x"], y1=xywh_dict["y"], w=xywh_dict["width"], h=xywh_dict["height"], discrete=discrete)

    def __str__(self):
        return dumps(self.json(), indent=self.indent)

    def __repr__(self):
        return str(self.json())

    def __iter__(self):
        """Iterate over the values of ``self.json()`` in key order."""
        return iter(self.json().values())

    def __eq__(self, rect):
        """Compare first corner and size; non-Rectangles defer to Python's default handling."""
        if not isinstance(rect, Rectangle):
            return NotImplemented
        return all([self.x1 == rect.x1, self.y1 == rect.y1, self.w == rect.w, self.h == rect.h])


def adjacent(alpha: Rectangle, beta: Rectangle, tolerance=7):
    """Check if the two rectangles are adjacent to each other.

    They are adjacent when a pair of facing edges lies within ``tolerance`` of each other and the
    rectangles' extents along the perpendicular axis overlap.
    """
    return any(
        juxt(
            #   +---+
            #   |   | +---+
            #   | a | | b |
            #   |   | +___+
            #   +___+
            alpha_is_left_of_beta_within_tolerance_and_beta_overlaps_alphas_y_range,
            #         +---+
            #   +---+ |   |
            #   | b | | a |
            #   +___+ |   |
            #         +___+
            alpha_is_right_of_beta_within_tolerance_and_beta_overlaps_alphas_y_range,
            #   +-----------+
            #   |     a     |
            #   +___________+
            #     +-----+
            #     |  b  |
            #     +_____+
            alpha_is_above_beta_within_tolerance_and_beta_overlaps_alphas_x_range,
            #     +-----+
            #     |  b  |
            #     +_____+
            #   +-----------+
            #   |     a     |
            #   +___________+
            alpha_is_below_beta_within_tolerance_and_beta_overlaps_alphas_x_range,
        )(alpha, beta, tolerance)
    )


def alpha_is_left_of_beta_within_tolerance_and_beta_overlaps_alphas_y_range(alpha: Rectangle, beta: Rectangle, tol):
    """Check if the first rectangle is left of the other within a tolerance and also overlaps the other's y range."""
    return adjacent_along_one_axis_and_overlapping_along_perpendicular_axis(
        alpha.x2, beta.x1, beta.y1, beta.y2, alpha.y1, alpha.y2, tolerance=tol
    )


def alpha_is_right_of_beta_within_tolerance_and_beta_overlaps_alphas_y_range(alpha: Rectangle, beta: Rectangle, tol):
    """Check if the first rectangle is right of the other within a tolerance and also overlaps the other's y range."""
    return adjacent_along_one_axis_and_overlapping_along_perpendicular_axis(
        alpha.x1, beta.x2, beta.y1, beta.y2, alpha.y1, alpha.y2, tolerance=tol
    )


def alpha_is_above_beta_within_tolerance_and_beta_overlaps_alphas_x_range(alpha: Rectangle, beta: Rectangle, tol):
    """Check if the first rectangle is above the other within a tolerance and also overlaps the other's x range."""
    return adjacent_along_one_axis_and_overlapping_along_perpendicular_axis(
        alpha.y2, beta.y1, beta.x1, beta.x2, alpha.x1, alpha.x2, tolerance=tol
    )


def alpha_is_below_beta_within_tolerance_and_beta_overlaps_alphas_x_range(alpha: Rectangle, beta: Rectangle, tol):
    """Check if the first rectangle is below the other within a tolerance and also overlaps the other's x range."""
    return adjacent_along_one_axis_and_overlapping_along_perpendicular_axis(
        alpha.y1, beta.y2, beta.x1, beta.x2, alpha.x1, alpha.x2, tolerance=tol
    )


def adjacent_along_one_axis_and_overlapping_along_perpendicular_axis(
    axis_0_point_1,
    axis_1_point_2,
    axis_1_contained_point_1,
    axis_1_contained_point_2,
    axis_1_lower_bound,
    axis_1_upper_bound,
    tolerance,
):
    """Check that two edges are close along one axis and two ranges overlap along the other.

    Args:
        axis_0_point_1: Edge coordinate of the first rectangle on the adjacency axis.
        axis_1_point_2: Facing edge coordinate of the second rectangle on the same axis.
        axis_1_contained_point_1, axis_1_contained_point_2: Endpoints of the second rectangle's
            range on the perpendicular axis (in either order).
        axis_1_lower_bound, axis_1_upper_bound: The first rectangle's range on the perpendicular axis.
        tolerance: Maximum allowed distance between the two edges.

    Returns:
        True if the edges are within ``tolerance`` of each other and the two perpendicular ranges
        share at least one point.
    """
    edges_close = abs(axis_0_point_1 - axis_1_point_2) <= tolerance
    # Standard interval-overlap test. Only checking whether one of the second range's endpoints
    # falls inside the first range misses the case where the second range strictly contains the
    # first, which made adjacent(a, b) disagree with adjacent(b, a).
    other_lo = min(axis_1_contained_point_1, axis_1_contained_point_2)
    other_hi = max(axis_1_contained_point_1, axis_1_contained_point_2)
    ranges_overlap = max(axis_1_lower_bound, other_lo) <= min(axis_1_upper_bound, other_hi)
    return bool(edges_close and ranges_overlap)