Refactoring
Clean up and prove correctness of intersection computation
This commit is contained in:
parent
b592497b75
commit
47e657aaa3
@ -1,6 +1,7 @@
|
||||
# See https://stackoverflow.com/a/39757388/3578468
|
||||
from __future__ import annotations
|
||||
|
||||
from operator import attrgetter
|
||||
from typing import TYPE_CHECKING, Iterable
|
||||
|
||||
from funcy import juxt, rpartial
|
||||
@ -109,3 +110,43 @@ def is_contained(rectangle: Rectangle, rectangles: Iterable[Rectangle]):
|
||||
"""Check if the rectangle is contained within any of the other rectangles."""
|
||||
other_rectangles = filter(lambda r: r != rectangle, rectangles)
|
||||
return any(map(rpartial(contains, rectangle), other_rectangles))
|
||||
|
||||
|
||||
def intersection(alpha, beta):
|
||||
return intersection_along_x_axis(alpha, beta) * intersection_along_y_axis(alpha, beta)
|
||||
|
||||
|
||||
def intersection_along_x_axis(alpha, beta):
|
||||
return intersection_along_axis(alpha, beta, "x")
|
||||
|
||||
|
||||
def intersection_along_y_axis(alpha, beta):
|
||||
return intersection_along_axis(alpha, beta, "y")
|
||||
|
||||
|
||||
def intersection_along_axis(alpha, beta, axis):
|
||||
assert axis in ["x", "y"]
|
||||
|
||||
def get_component_accessor(component):
|
||||
return attrgetter(f"{axis}{component}")
|
||||
|
||||
c1 = get_component_accessor(1)
|
||||
c2 = get_component_accessor(2)
|
||||
|
||||
# Cases:
|
||||
# a b
|
||||
# [-----] (---) => [a1, b1, a2, b2] => max(0, (a2 - b1)) = 0
|
||||
# b a
|
||||
# (---) [-----] => [b1, a1, b2, a2] => max(0, (b2 - a1)) = 0
|
||||
# a b
|
||||
# [--(----]----) => [a1, b1, a2, b2] => max(0, (a2 - b1)) = (a2 - b1)
|
||||
# a b
|
||||
# (-[---]----) => [b1, a1, a2, b2] => max(0, (a2 - a1)) = (a2 - a1)
|
||||
# b a
|
||||
# [-(---)----] => [a1, b1, b2, a2] => max(0, (b2 - b1)) = (b2 - b1)
|
||||
# b a
|
||||
# (----[--)----] => [b1, a1, b2, a2] => max(0, (b2 - a1)) = (b2 - a1)
|
||||
|
||||
coords = [*sorted([c1(alpha), c1(beta)]), *sorted([c2(alpha), c2(beta)])]
|
||||
intersection = max(0, coords[2] - coords[1])
|
||||
return intersection
|
||||
|
||||
@ -4,7 +4,7 @@ from typing import Iterable
|
||||
import numpy as np
|
||||
from funcy import identity
|
||||
|
||||
from cv_analysis.utils.spacial import adjacent, contains
|
||||
from cv_analysis.utils.spacial import adjacent, contains, intersection
|
||||
|
||||
|
||||
class Rectangle:
|
||||
@ -52,19 +52,15 @@ class Rectangle:
|
||||
def xywh(self):
|
||||
return self.x1, self.y1, self.w, self.h
|
||||
|
||||
def intersection(self, rect):
|
||||
bx1, by1, bx2, by2 = rect.xyxy()
|
||||
if (self.x1 > bx2) or (bx1 > self.x2) or (self.y1 > by2) or (by1 > self.y2):
|
||||
return 0
|
||||
intersection_ = (min(self.x2, bx2) - max(self.x1, bx1)) * (min(self.y2, by2) - max(self.y1, by1))
|
||||
return intersection_
|
||||
def intersection(self, other):
|
||||
return intersection(self, other)
|
||||
|
||||
def area(self):
|
||||
return (self.x2 - self.x1) * (self.y2 - self.y1)
|
||||
|
||||
def iou(self, rect):
|
||||
intersection = self.intersection(rect)
|
||||
union = self.area() + rect.area() - intersection
|
||||
def iou(self, other):
|
||||
intersection = self.intersection(other)
|
||||
union = self.area() + other.area() - intersection
|
||||
return intersection / union
|
||||
|
||||
def includes(self, other: "Rectangle", tol=3):
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user