Refactoring

This commit is contained in:
Matthias Bisping 2023-01-04 10:58:24 +01:00
parent c0d961bc39
commit b592497b75
3 changed files with 118 additions and 101 deletions

View File

@ -1,7 +1,7 @@
from collections import namedtuple
from functools import partial from functools import partial
from itertools import starmap, compress from itertools import starmap, compress
from typing import Iterable, List from typing import Iterable, List
from cv_analysis.utils.structures import Rectangle from cv_analysis.utils.structures import Rectangle

View File

@ -0,0 +1,111 @@
# See https://stackoverflow.com/a/39757388/3578468
from __future__ import annotations
from typing import TYPE_CHECKING, Iterable
from funcy import juxt, rpartial
if TYPE_CHECKING:
from cv_analysis.utils.structures import Rectangle
def adjacent(alpha: Rectangle, beta: Rectangle, tolerance=7):
"""Check if the two rectangles are adjacent to each other."""
return any(
juxt(
# +---+
# | | +---+
# | a | | b |
# | | +___+
# +___+
right_left_aligned_and_vertically_overlapping,
# +---+
# +---+ | |
# | b | | a |
# +___+ | |
# +___+
left_right_aligned_and_vertically_overlapping,
# +-----------+
# | a |
# +___________+
# +-----+
# | b |
# +_____+
bottom_top_aligned_and_horizontally_overlapping,
# +-----+
# | b |
# +_____+
# +-----------+
# | a |
# +___________+
top_bottom_aligned_and_horizontally_overlapping,
)(alpha, beta, tolerance)
)
def right_left_aligned_and_vertically_overlapping(alpha: Rectangle, beta: Rectangle, tol):
"""Check if the first rectangle is left of the other within a tolerance and also overlaps the other's y range."""
return adjacent_along_one_axis_and_overlapping_along_perpendicular_axis(
alpha.x2, beta.x1, beta.y1, beta.y2, alpha.y1, alpha.y2, tolerance=tol
)
def left_right_aligned_and_vertically_overlapping(alpha: Rectangle, beta: Rectangle, tol):
"""Check if the first rectangle is right of the other within a tolerance and also overlaps the other's y range."""
return adjacent_along_one_axis_and_overlapping_along_perpendicular_axis(
alpha.x1, beta.x2, beta.y1, beta.y2, alpha.y1, alpha.y2, tolerance=tol
)
def bottom_top_aligned_and_horizontally_overlapping(alpha: Rectangle, beta: Rectangle, tol):
"""Check if the first rectangle is above the other within a tolerance and also overlaps the other's x range."""
return adjacent_along_one_axis_and_overlapping_along_perpendicular_axis(
alpha.y2, beta.y1, beta.x1, beta.x2, alpha.x1, alpha.x2, tolerance=tol
)
def top_bottom_aligned_and_horizontally_overlapping(alpha: Rectangle, beta: Rectangle, tol):
"""Check if the first rectangle is below the other within a tolerance and also overlaps the other's x range."""
return adjacent_along_one_axis_and_overlapping_along_perpendicular_axis(
alpha.y1, beta.y2, beta.x1, beta.x2, alpha.x1, alpha.x2, tolerance=tol
)
def adjacent_along_one_axis_and_overlapping_along_perpendicular_axis(
axis_0_point_1,
axis_1_point_2,
axis_1_contained_point_1,
axis_1_contained_point_2,
axis_1_lower_bound,
axis_1_upper_bound,
tolerance,
):
"""Check if two points are adjacent along one axis and two other points overlap a range along the perpendicular
axis."""
return all(
[
abs(axis_0_point_1 - axis_1_point_2) <= tolerance,
any(
[
axis_1_lower_bound <= p <= axis_1_upper_bound
for p in [axis_1_contained_point_1, axis_1_contained_point_2]
]
),
]
)
def contains(alpha: Rectangle, beta: Rectangle, tol=3):
"""Check if the first rectangle contains the second rectangle."""
return (
beta.x1 + tol >= alpha.x1
and beta.y1 + tol >= alpha.y1
and beta.x2 - tol <= alpha.x2
and beta.y2 - tol <= alpha.y2
)
def is_contained(rectangle: Rectangle, rectangles: Iterable[Rectangle]):
"""Check if the rectangle is contained within any of the other rectangles."""
other_rectangles = filter(lambda r: r != rectangle, rectangles)
return any(map(rpartial(contains, rectangle), other_rectangles))

View File

@ -2,7 +2,9 @@ from json import dumps
from typing import Iterable from typing import Iterable
import numpy as np import numpy as np
from funcy import identity, juxt from funcy import identity
from cv_analysis.utils.spacial import adjacent, contains
class Rectangle: class Rectangle:
@ -62,27 +64,17 @@ class Rectangle:
def iou(self, rect): def iou(self, rect):
intersection = self.intersection(rect) intersection = self.intersection(rect)
if intersection == 0:
return 0
union = self.area() + rect.area() - intersection union = self.area() + rect.area() - intersection
return intersection / union return intersection / union
def includes(self, other: "Rectangle", tol=3): def includes(self, other: "Rectangle", tol=3):
"""does a include b?""" return contains(self, other, tol)
return (
other.x1 + tol >= self.x1
and other.y1 + tol >= self.y1
and other.x2 - tol <= self.x2
and other.y2 - tol <= self.y2
)
def is_included(self, rectangles: Iterable["Rectangle"]): def is_included(self, rectangles: Iterable["Rectangle"]):
return any(rect.includes(self) for rect in rectangles if not rect == self) return any(rect.includes(self) for rect in rectangles if not rect == self)
def adjacent(self, rect2: "Rectangle", tolerance=7): def adjacent(self, other: "Rectangle", tolerance=7):
if rect2 is None: return adjacent(self, other, tolerance)
return False
return adjacent(self, rect2, tolerance)
@classmethod @classmethod
def from_xyxy(cls, xyxy_tuple, discrete=True): def from_xyxy(cls, xyxy_tuple, discrete=True):
@ -109,89 +101,3 @@ class Rectangle:
def __eq__(self, rect): def __eq__(self, rect):
return all([self.x1 == rect.x1, self.y1 == rect.y1, self.w == rect.w, self.h == rect.h]) return all([self.x1 == rect.x1, self.y1 == rect.y1, self.w == rect.w, self.h == rect.h])
def adjacent(alpha: Rectangle, beta: Rectangle, tolerance=7):
"""Check if the two rectangles are adjacent to each other."""
return any(
juxt(
# +---+
# | | +---+
# | a | | b |
# | | +___+
# +___+
alpha_is_left_of_beta_within_tolerance_and_beta_overlaps_alphas_y_range,
# +---+
# +---+ | |
# | b | | a |
# +___+ | |
# +___+
alpha_is_right_of_beta_within_tolerance_and_beta_overlaps_alphas_y_range,
# +-----------+
# | a |
# +___________+
# +-----+
# | b |
# +_____+
alpha_is_above_beta_within_tolerance_and_beta_overlaps_alphas_x_range,
# +-----+
# | b |
# +_____+
# +-----------+
# | a |
# +___________+
alpha_is_below_beta_within_tolerance_and_beta_overlaps_alphas_x_range,
)(alpha, beta, tolerance)
)
def alpha_is_left_of_beta_within_tolerance_and_beta_overlaps_alphas_y_range(alpha: Rectangle, beta: Rectangle, tol):
"""Check if the first rectangle is left of the other within a tolerance and also overlaps the other's y range."""
return adjacent_along_one_axis_and_overlapping_along_perpendicular_axis(
alpha.x2, beta.x1, beta.y1, beta.y2, alpha.y1, alpha.y2, tolerance=tol
)
def alpha_is_right_of_beta_within_tolerance_and_beta_overlaps_alphas_y_range(alpha: Rectangle, beta: Rectangle, tol):
"""Check if the first rectangle is right of the other within a tolerance and also overlaps the other's y range."""
return adjacent_along_one_axis_and_overlapping_along_perpendicular_axis(
alpha.x1, beta.x2, beta.y1, beta.y2, alpha.y1, alpha.y2, tolerance=tol
)
def alpha_is_above_beta_within_tolerance_and_beta_overlaps_alphas_x_range(alpha: Rectangle, beta: Rectangle, tol):
"""Check if the first rectangle is above the other within a tolerance and also overlaps the other's x range."""
return adjacent_along_one_axis_and_overlapping_along_perpendicular_axis(
alpha.y2, beta.y1, beta.x1, beta.x2, alpha.x1, alpha.x2, tolerance=tol
)
def alpha_is_below_beta_within_tolerance_and_beta_overlaps_alphas_x_range(alpha: Rectangle, beta: Rectangle, tol):
"""Check if the first rectangle is below the other within a tolerance and also overlaps the other's x range."""
return adjacent_along_one_axis_and_overlapping_along_perpendicular_axis(
alpha.y1, beta.y2, beta.x1, beta.x2, alpha.x1, alpha.x2, tolerance=tol
)
def adjacent_along_one_axis_and_overlapping_along_perpendicular_axis(
axis_0_point_1,
axis_1_point_2,
axis_1_contained_point_1,
axis_1_contained_point_2,
axis_1_lower_bound,
axis_1_upper_bound,
tolerance,
):
"""Check if two points are adjacent along one axis and two other points overlap a range along the perpendicular
axis."""
return all(
[
abs(axis_0_point_1 - axis_1_point_2) <= tolerance,
any(
[
axis_1_lower_bound <= p <= axis_1_upper_bound
for p in [axis_1_contained_point_1, axis_1_contained_point_2]
]
),
]
)