diff --git a/cv_analysis/utils/postprocessing.py b/cv_analysis/utils/postprocessing.py index 040f4bb..2f40c48 100644 --- a/cv_analysis/utils/postprocessing.py +++ b/cv_analysis/utils/postprocessing.py @@ -9,7 +9,7 @@ def remove_overlapping(rectangles: Iterable[Rectangle]) -> List[Rectangle]: def overlap(a: Rectangle, rect2: Rectangle) -> float: return a.intersection(rect2) > 0 - def does_not_overlap(rect: Rectangle, rectangles: Iterable[Rectangle]) -> list: + def does_not_overlap(rect: Rectangle, rectangles: Iterable[Rectangle]) -> bool: return not any(overlap(rect, rect2) for rect2 in rectangles if not rect == rect2) rectangles = list(filter(partial(does_not_overlap, rectangles=rectangles), rectangles)) @@ -25,7 +25,10 @@ def __remove_isolated_unsorted(rectangles: Iterable[Rectangle]) -> List[Rectangl def is_connected(rect: Rectangle, rectangles: Iterable[Rectangle]): return any(rect.adjacent(rect2) for rect2 in rectangles if not rect == rect2) - rectangles = list(filter(partial(is_connected, rectangles=list(rectangles)), rectangles)) + if not isinstance(rectangles, list): + rectangles = list(rectangles) + + rectangles = list(filter(partial(is_connected, rectangles=rectangles), rectangles)) return rectangles