Refactoring

Clean up and prove correctness of intersection computation
This commit is contained in:
Matthias Bisping 2023-01-04 12:04:57 +01:00
parent b592497b75
commit 47e657aaa3
2 changed files with 47 additions and 10 deletions

View File

@ -1,6 +1,7 @@
# See https://stackoverflow.com/a/39757388/3578468
from __future__ import annotations
from operator import attrgetter
from typing import TYPE_CHECKING, Iterable
from funcy import juxt, rpartial
@ -109,3 +110,43 @@ def is_contained(rectangle: Rectangle, rectangles: Iterable[Rectangle]):
"""Check if the rectangle is contained within any of the other rectangles."""
other_rectangles = filter(lambda r: r != rectangle, rectangles)
return any(map(rpartial(contains, rectangle), other_rectangles))
def intersection(alpha, beta):
return intersection_along_x_axis(alpha, beta) * intersection_along_y_axis(alpha, beta)
def intersection_along_x_axis(alpha, beta):
return intersection_along_axis(alpha, beta, "x")
def intersection_along_y_axis(alpha, beta):
return intersection_along_axis(alpha, beta, "y")
def intersection_along_axis(alpha, beta, axis):
assert axis in ["x", "y"]
def get_component_accessor(component):
return attrgetter(f"{axis}{component}")
c1 = get_component_accessor(1)
c2 = get_component_accessor(2)
# Cases:
# a b
# [-----] (---) => [a1, b1, a2, b2] => max(0, (a2 - b1)) = 0
# b a
# (---) [-----] => [b1, a1, b2, a2] => max(0, (b2 - a1)) = 0
# a b
# [--(----]----) => [a1, b1, a2, b2] => max(0, (a2 - b1)) = (a2 - b1)
# a b
# (-[---]----) => [b1, a1, a2, b2] => max(0, (a2 - a1)) = (a2 - a1)
# b a
# [-(---)----] => [a1, b1, b2, a2] => max(0, (b2 - b1)) = (b2 - b1)
# b a
# (----[--)----] => [b1, a1, b2, a2] => max(0, (b2 - a1)) = (b2 - a1)
coords = [*sorted([c1(alpha), c1(beta)]), *sorted([c2(alpha), c2(beta)])]
intersection = max(0, coords[2] - coords[1])
return intersection

View File

@ -4,7 +4,7 @@ from typing import Iterable
import numpy as np
from funcy import identity
from cv_analysis.utils.spacial import adjacent, contains
from cv_analysis.utils.spacial import adjacent, contains, intersection
class Rectangle:
@ -52,19 +52,15 @@ class Rectangle:
def xywh(self):
return self.x1, self.y1, self.w, self.h
def intersection(self, rect):
bx1, by1, bx2, by2 = rect.xyxy()
if (self.x1 > bx2) or (bx1 > self.x2) or (self.y1 > by2) or (by1 > self.y2):
return 0
intersection_ = (min(self.x2, bx2) - max(self.x1, bx1)) * (min(self.y2, by2) - max(self.y1, by1))
return intersection_
def intersection(self, other):
return intersection(self, other)
def area(self):
return (self.x2 - self.x1) * (self.y2 - self.y1)
def iou(self, rect):
intersection = self.intersection(rect)
union = self.area() + rect.area() - intersection
def iou(self, other):
intersection = self.intersection(other)
union = self.area() + other.area() - intersection
return intersection / union
def includes(self, other: "Rectangle", tol=3):