From 47e657aaa33a08388ba1d5060734aad359726919 Mon Sep 17 00:00:00 2001 From: Matthias Bisping Date: Wed, 4 Jan 2023 12:04:57 +0100 Subject: [PATCH] Refactoring Clean up and prove correctness of intersection computation --- cv_analysis/utils/spacial.py | 41 +++++++++++++++++++++++++++++++++ cv_analysis/utils/structures.py | 16 +++++-------- 2 files changed, 47 insertions(+), 10 deletions(-) diff --git a/cv_analysis/utils/spacial.py b/cv_analysis/utils/spacial.py index 9504c2a..863ada0 100644 --- a/cv_analysis/utils/spacial.py +++ b/cv_analysis/utils/spacial.py @@ -1,6 +1,7 @@ # See https://stackoverflow.com/a/39757388/3578468 from __future__ import annotations +from operator import attrgetter from typing import TYPE_CHECKING, Iterable from funcy import juxt, rpartial @@ -109,3 +110,43 @@ def is_contained(rectangle: Rectangle, rectangles: Iterable[Rectangle]): """Check if the rectangle is contained within any of the other rectangles.""" other_rectangles = filter(lambda r: r != rectangle, rectangles) return any(map(rpartial(contains, rectangle), other_rectangles)) + + +def intersection(alpha, beta): + return intersection_along_x_axis(alpha, beta) * intersection_along_y_axis(alpha, beta) + + +def intersection_along_x_axis(alpha, beta): + return intersection_along_axis(alpha, beta, "x") + + +def intersection_along_y_axis(alpha, beta): + return intersection_along_axis(alpha, beta, "y") + + +def intersection_along_axis(alpha, beta, axis): + assert axis in ["x", "y"] + + def get_component_accessor(component): + return attrgetter(f"{axis}{component}") + + c1 = get_component_accessor(1) + c2 = get_component_accessor(2) + + # Cases: + # a b + # [-----] (---) => [a1, b1, a2, b2] => max(0, (a2 - b1)) = 0 + # b a + # (---) [-----] => [b1, a1, b2, a2] => max(0, (b2 - a1)) = 0 + # a b + # [--(----]----) => [a1, b1, a2, b2] => max(0, (a2 - b1)) = (a2 - b1) + # a b + # (-[---]----) => [b1, a1, a2, b2] => max(0, (a2 - a1)) = (a2 - a1) + # b a + # [-(---)----] => [a1, b1, b2, a2] => max(0, (b2 - b1)) = (b2 - b1) + # b a + # (----[--)----] => [b1, a1, b2, a2] => max(0, (b2 - a1)) = (b2 - a1) + + coords = [*sorted([c1(alpha), c1(beta)]), *sorted([c2(alpha), c2(beta)])] + intersection = max(0, coords[2] - coords[1]) + return intersection diff --git a/cv_analysis/utils/structures.py b/cv_analysis/utils/structures.py index dd3223b..89fdaa0 100644 --- a/cv_analysis/utils/structures.py +++ b/cv_analysis/utils/structures.py @@ -4,7 +4,7 @@ from typing import Iterable import numpy as np from funcy import identity -from cv_analysis.utils.spacial import adjacent, contains +from cv_analysis.utils.spacial import adjacent, contains, intersection class Rectangle: @@ -52,19 +52,15 @@ class Rectangle: def xywh(self): return self.x1, self.y1, self.w, self.h - def intersection(self, rect): - bx1, by1, bx2, by2 = rect.xyxy() - if (self.x1 > bx2) or (bx1 > self.x2) or (self.y1 > by2) or (by1 > self.y2): - return 0 - intersection_ = (min(self.x2, bx2) - max(self.x1, bx1)) * (min(self.y2, by2) - max(self.y1, by1)) - return intersection_ + def intersection(self, other): + return intersection(self, other) def area(self): return (self.x2 - self.x1) * (self.y2 - self.y1) - def iou(self, rect): - intersection = self.intersection(rect) - union = self.area() + rect.area() - intersection + def iou(self, other): + intersection = self.intersection(other) + union = self.area() + other.area() - intersection return intersection / union def includes(self, other: "Rectangle", tol=3):