diff --git a/cv_analysis/figure_detection/figure_detection_pipeline.py b/cv_analysis/figure_detection/figure_detection_pipeline.py index 1a374f1..0dce962 100644 --- a/cv_analysis/figure_detection/figure_detection_pipeline.py +++ b/cv_analysis/figure_detection/figure_detection_pipeline.py @@ -10,7 +10,7 @@ from cv_analysis.utils.filters import ( has_acceptable_format, is_not_too_large, ) -from cv_analysis.utils.post_processing import remove_included +from cv_analysis.utils.postprocessing import remove_included from cv_analysis.utils.structures import Rectangle @@ -23,9 +23,11 @@ def make_figure_detection_pipeline(min_area=5000, max_width_to_height_ratio=6): cnts = detect_large_coherent_structures(image) cnts = filter_cnts(cnts) - rects = remove_included(map(cv2.boundingRect, cnts)) - rectangles = map(Rectangle.from_xywh, rects) - return rectangles + rects = map(cv2.boundingRect, cnts) + rects = map(Rectangle.from_xywh, rects) + rects = remove_included(rects) + + return rects return pipeline diff --git a/cv_analysis/layout_parsing.py b/cv_analysis/layout_parsing.py index 1e6171e..d83e8a5 100644 --- a/cv_analysis/layout_parsing.py +++ b/cv_analysis/layout_parsing.py @@ -6,7 +6,7 @@ import cv2 import numpy as np from cv_analysis.utils.structures import Rectangle -from cv_analysis.utils.post_processing import ( +from cv_analysis.utils.postprocessing import ( remove_overlapping, remove_included, has_no_parent, diff --git a/cv_analysis/table_parsing.py b/cv_analysis/table_parsing.py index b601742..06a27ed 100644 --- a/cv_analysis/table_parsing.py +++ b/cv_analysis/table_parsing.py @@ -7,7 +7,7 @@ import numpy as np from funcy import lmap from cv_analysis.layout_parsing import parse_layout -from cv_analysis.utils.post_processing import xywh_to_vecs, xywh_to_vec_rect, adjacent1d +from cv_analysis.utils.postprocessing import remove_isolated # xywh_to_vecs, xywh_to_vec_rect, adjacent1d from cv_analysis.utils.structures import Rectangle from cv_analysis.utils.visual_logging import vizlogger @@ -86,29 +86,6 @@ def isolate_vertical_and_horizontal_components(img_bin): return img_bin_final -def has_table_shape(rects): - assert isinstance(rects, list) - - points = list(chain(*map(xywh_to_vecs, rects))) - brect = xywh_to_vec_rect(cv2.boundingRect(np.vstack(points))) - - rects = list(map(xywh_to_vec_rect, rects)) - - def matches_bounding_rect_corner(rect, x, y): - corresp_coords = list(zip(*map(attrgetter(x, y), [brect, rect]))) - ret = all(starmap(partial(adjacent1d, tolerance=30), corresp_coords)) - return ret - - return all( - ( - any(matches_bounding_rect_corner(r, "xmin", "ymin") for r in rects), - any(matches_bounding_rect_corner(r, "xmin", "ymax") for r in rects), - any(matches_bounding_rect_corner(r, "xmax", "ymax") for r in rects), - any(matches_bounding_rect_corner(r, "xmax", "ymin") for r in rects), - ) - ) - - def find_table_layout_boxes(image: np.array): def is_large_enough(box): (x, y, w, h) = box @@ -117,7 +94,6 @@ def find_table_layout_boxes(image: np.array): layout_boxes = parse_layout(image) a = lmap(is_large_enough, layout_boxes) - print(a) return lmap(is_large_enough, layout_boxes) @@ -127,7 +103,7 @@ def preprocess(image: np.array): return ~image -def turn_connected_components_into_rects(image): +def turn_connected_components_into_rects(image: np.array): def is_large_enough(stat): x1, y1, w, h, area = stat return area > 2000 and w > 35 and h > 25 @@ -149,9 +125,12 @@ def parse_tables(image: np.array, show=False): """ image = preprocess(image) - image = isolate_vertical_and_horizontal_components(image) - rects = turn_connected_components_into_rects(image) - - return list(map(Rectangle.from_xywh, rects)) + #print(rects, "\n\n") + rects = list(map(Rectangle.from_xywh, rects)) + #print(rects, "\n\n") + rects = remove_isolated(rects) + #print(rects, "\n\n") + + return rects diff --git a/cv_analysis/utils/post_processing.py b/cv_analysis/utils/post_processing.py deleted file mode 100644 index 1749f2d..0000000 --- a/cv_analysis/utils/post_processing.py +++ /dev/null @@ -1,140 +0,0 @@ -from collections import namedtuple -from functools import partial -from itertools import starmap, compress - - -def remove_overlapping(rectangles): - def overlap(a, b): - return compute_intersection(a, b) > 0 - - def does_not_overlap(rect, rectangles): - return not any(overlap(rect, r2) for r2 in rectangles if not rect == r2) - - rectangles = list(map(xywh_to_vec_rect, rectangles)) - rectangles = filter(partial(does_not_overlap, rectangles=rectangles), rectangles) - rectangles = map(vec_rect_to_xywh, rectangles) - return rectangles - - -def remove_included(rectangles): - def included(a, b): - return b.xmin >= a.xmin and b.ymin >= a.ymin and b.xmax <= a.xmax and b.ymax <= a.ymax - - def includes(a, b, tol=3): - """does a include b?""" - return b.xmin + tol >= a.xmin and b.ymin + tol >= a.ymin and b.xmax - tol <= a.xmax and b.ymax - tol <= a.ymax - - def is_not_included(rect, rectangles): - return not any(includes(r2, rect) for r2 in rectangles if not rect == r2) - - rectangles = list(map(xywh_to_vec_rect, rectangles)) - rectangles = filter(partial(is_not_included, rectangles=rectangles), rectangles) - rectangles = map(vec_rect_to_xywh, rectangles) - return rectangles - - -# tolerance was set too low (1) most lines are 2px wide -def adjacent1d(n, m, tolerance=4): - return abs(n - m) <= tolerance - - -Rectangle = namedtuple("Rectangle", "xmin ymin xmax ymax") - - -def adjacent(a, b): - """Two rects (v1, v2), (w1, w2) are adjacent if either of: - - the x components of v2 and w1 match and the y components of w1 or w2 are in the range of the y components of v1 and v2 - - the x components of v1 and w2 match and the y components of w1 or w2 are in the range of the y components of v1 and v2 - - the y components of v2 and w1 match and the x components of w1 or w2 are in the range of the x components of v1 and v2 - - the y components of v1 and w2 match and the x components of w1 or w2 are in the range of the x components of v1 and v2 - """ - - def adjacent2d(g, h, i, j, k, l): - # print(adjacent1d(g, h), any(k <= p <= l for p in [i, j])) - return adjacent1d(g, h) and any(k <= p <= l for p in [i, j]) - - if any(x is None for x in (a, b)): - return False - v1 = a.xmin, a.ymin - v2 = a.xmax, a.ymax - w1 = b.xmin, b.ymin - w2 = b.xmax, b.ymax - return any( - ( - adjacent2d(v2[0], w1[0], w1[1], w2[1], v1[1], v2[1]), - adjacent2d(v1[0], w2[0], w1[1], w2[1], v1[1], v2[1]), - adjacent2d(v2[1], w1[1], w1[0], w2[0], v1[0], v2[0]), - adjacent2d(v1[1], w2[1], w1[0], w2[0], v1[0], v2[0]), - ) - ) - - -# FIXME: For some reason some isolated rects remain. -def __remove_isolated_unsorted(rectangles): - def is_connected(rect, rectangles): - return any(adjacent(r2, rect) for r2 in rectangles if not rect == r2) - - rectangles = list(map(xywh_to_vec_rect, rectangles)) - rectangles = filter(partial(is_connected, rectangles=rectangles), rectangles) - rectangles = map(vec_rect_to_xywh, rectangles) - return rectangles - - -def make_box(x1, y1, x2, y2): - keys = "x1", "y1", "x2", "y2" - return dict(zip(keys, [x1, y1, x2, y2])) - - -def __remove_isolated_sorted(rectangles): - def is_connected(left, center, right): - # print(left,center,right) - return any(starmap(adjacent, [(left, center), (center, right)])) - - rectangles = list(map(xywh_to_vec_rect, rectangles)) - lefts = [None, *rectangles[:-1]] - rights = [*rectangles[1:], None] - mask = starmap(is_connected, zip(lefts, rectangles, rights)) - rectangles = compress(rectangles, mask) - rectangles = map(vec_rect_to_xywh, rectangles) - return rectangles - - -def remove_isolated(rectangles, input_sorted=False): - return (__remove_isolated_sorted if input_sorted else __remove_isolated_unsorted)(rectangles) - - -Rectangle = namedtuple("Rectangle", "xmin ymin xmax ymax") - - -def compute_intersection(a, b): - dx = min(a.xmax, b.xmax) - max(a.xmin, b.xmin) - dy = min(a.ymax, b.ymax) - max(a.ymin, b.ymin) - return dx * dy if (dx >= 0) and (dy >= 0) else 0 - - -def has_no_parent(hierarchy): - return hierarchy[-1] <= 0 - - -def xywh_to_vec_rect(rect): - v1, v2 = xywh_to_vecs(rect) - return Rectangle(*v1, *v2) - - -def vecs_to_vec_rect(rect): - v1, v2 = rect - return Rectangle(*v1, *v2) - - -def xywh_to_vecs(rect): - x1, y1, w, h = rect - x2 = x1 + w - y2 = y1 + h - return (x1, y1), (x2, y2) - - -def vec_rect_to_xywh(rect): - x, y, x2, y2 = rect - w = x2 - x - h = y2 - y - return x, y, w, h diff --git a/cv_analysis/utils/postprocessing.py b/cv_analysis/utils/postprocessing.py new file mode 100644 index 0000000..d620696 --- /dev/null +++ b/cv_analysis/utils/postprocessing.py @@ -0,0 +1,50 @@ +from collections import namedtuple +from functools import partial +from itertools import starmap, compress +from typing import Iterable +from cv_analysis.utils.structures import Rectangle + + +def remove_overlapping(rectangles: Iterable[Rectangle]) -> list[Rectangle]: + def overlap(a: Rectangle, rect2: Rectangle) -> float: + return a.intersection(rect2) > 0 + + def does_not_overlap(rect: Rectangle, rectangles: Iterable[Rectangle]) -> list: + return not any(overlap(rect, rect2) for rect2 in rectangles if not rect == rect2) + + rectangles = list(filter(partial(does_not_overlap, rectangles=rectangles), rectangles)) + return rectangles + + +def remove_included(rectangles: Iterable[Rectangle]) -> list[Rectangle]: + rectangles = list(filter(partial(Rectangle.is_not_included, rectangles=rectangles), rectangles)) + return rectangles + + +def __remove_isolated_unsorted(rectangles: Iterable[Rectangle]) -> list[Rectangle]: + def is_connected(rect: Rectangle, rectangles: Iterable[Rectangle]): + return any(rect.adjacent(rect2) for rect2 in rectangles if not rect == rect2) + + rectangles = list(filter(partial(is_connected, rectangles=list(rectangles)), rectangles)) + return rectangles + + +def __remove_isolated_sorted(rectangles: Iterable[Rectangle]) -> list[Rectangle]: + def is_connected(left, center, right): + return any([left.adjacent(center), center.adjacent(right)]) + + rectangles = list(rectangles) + lefts = [None, *rectangles[:-1]] + rights = [*rectangles[1:], None] + mask = starmap(is_connected, zip(lefts, rectangles, rights)) + + rectangles = list(compress(rectangles, mask)) + return rectangles + + +def remove_isolated(rectangles: Iterable[Rectangle], input_unsorted=True) -> list[Rectangle]: + return (__remove_isolated_unsorted if input_unsorted else __remove_isolated_sorted)(rectangles) + + +def has_no_parent(hierarchy): + return hierarchy[-1] <= 0 diff --git a/cv_analysis/utils/structures.py b/cv_analysis/utils/structures.py index 076e1a3..caf60e3 100644 --- a/cv_analysis/utils/structures.py +++ b/cv_analysis/utils/structures.py @@ -1,5 +1,6 @@ from json import dumps +from typing import Iterable import numpy as np from funcy import identity @@ -48,6 +49,55 @@ class Rectangle: def xywh(self): return self.x1, self.y1, self.w, self.h + def intersection(self, rect): + bx1, by1, bx2, by2 = rect.xyxy() + if (self.x1 > bx2) or (bx1 > self.x2) or (self.y1 > by2) or (by1 > self.y2): + return 0 + intersection_ = (min(self.x2, bx2) - max(self.x1, bx1)) * (min(self.y2, by2) - max(self.y1, by1)) + return intersection_ + + def area(self): + return (self.x2 - self.x1) * (self.y2 - self.y1) + + def iou(self, rect): + intersection = self.intersection(rect) + if intersection == 0: + return 0 + union = self.area() + rect.area() - intersection + return intersection / union + + def includes(self, rect: "Rectangle", tol=3): + """does a include b?""" + return ( + rect.x1 + tol >= self.x1 + and rect.y1 + tol >= self.y1 + and rect.x2 - tol <= self.x2 + and rect.y2 - tol <= self.y2 + ) + + def is_not_included(self, rectangles: Iterable["Rectangle"]): + return not any(self.includes(rect) for rect in rectangles if not rect == self) + + def adjacent(self, rect2: "Rectangle", tolerance=7): + # tolerance=1 was set too low; most lines are 2px wide + def adjacent2d(sixtuple): + g, h, i, j, k, l = sixtuple + return (abs(g - h) <= tolerance) and any(k <= p <= l for p in [i, j]) + + if rect2 is None: + return False + return any( + map( + adjacent2d, + [ + (self.x2, rect2.x1, rect2.y1, rect2.y2, self.y1, self.y2), + (self.x1, rect2.x2, rect2.y1, rect2.y2, self.y1, self.y2), + (self.y2, rect2.y1, rect2.x1, rect2.x2, self.x1, self.x2), + (self.y1, rect2.y2, rect2.x1, rect2.x2, self.x1, self.x2), + ], + ) + ) + @classmethod def from_xyxy(cls, xyxy_tuple, discrete=True): x1, y1, x2, y2 = xyxy_tuple @@ -58,6 +108,10 @@ class Rectangle: x, y, w, h = xywh_tuple return cls(x1=x, y1=y, w=w, h=h, discrete=discrete) + @classmethod + def from_dict_xywh(cls, xywh_dict, discrete=True): + return cls(x1=xywh_dict["x"], y1=xywh_dict["y"], w=xywh_dict["width"], h=xywh_dict["height"], discrete=discrete) + def __str__(self): return dumps(self.json(), indent=self.indent) @@ -67,6 +121,9 @@ class Rectangle: def __iter__(self): return list(self.json().values()).__iter__() + def __eq__(self, rect): + return all([self.x1 == rect.x1, self.y1 == rect.y1, self.w == rect.w, self.h == rect.h]) + class Contour: def __init__(self): diff --git a/cv_analysis/utils/test_metrics.py b/cv_analysis/utils/test_metrics.py index 8df3d00..9e46e97 100644 --- a/cv_analysis/utils/test_metrics.py +++ b/cv_analysis/utils/test_metrics.py @@ -1,66 +1,23 @@ +from typing import Iterable import numpy as np from cv_analysis.utils.structures import Rectangle -def xyxy_from_object(box_object): - try: - x1, y1, x2, y2 = box_object.xyxy() - except: - try: - x1 = box_object["x"] - y1 = box_object["y"] - x2 = x1 + box_object["width"] - y2 = y1 + box_object["height"] - except: - x1, y1, x2, y2 = box_object - return x1, y1, x2, y2 - - -def xywh_from_object(box_object): - try: - x, y, w, h = box_object.xywh() - except: - try: - x = box_object["x"] - y = box_object["y"] - w = box_object["width"] - h = box_object["height"] - except: - x, y, w, h = box_object - return x, y, w, h - - -def compute_iou_from_boxes(box1: Rectangle, box2: list): - """ - Each box of the form (x1, y1, delx, dely) - """ - ax1, ay1, aw, ah = xywh_from_object(box1) - bx1, by1, bw, bh = xywh_from_object(box2) - ax2, ay2, bx2, by2 = ax1 + aw, ay1 + ah, bx1 + bw, by1 + bh - if (ax1 > bx2) or (bx1 > ax2) or (ay1 > by2) or (by1 > ay2): - return 0 - intersection = (min(ax2, bx2) - max(ax1, bx1)) * (min(ay2, by2) - max(ay1, by1)) - area_a = (ax2 - ax1) * (ay2 - ay1) - area_b = (bx2 - bx1) * (by2 - by1) - union = area_a + area_b - intersection - return intersection / union - - -def find_max_overlap(box, box_list): - best_candidate = max(box_list, key=lambda x: compute_iou_from_boxes(box, x)) - iou = compute_iou_from_boxes(box, best_candidate) +def find_max_overlap(box: Rectangle, box_list: Iterable[Rectangle]): + best_candidate = max(box_list, key=lambda x: box.iou(x)) + iou = box.iou(best_candidate) return best_candidate, iou -def compute_page_iou(results_box_list, gt_box_list): - results = results_box_list.copy() - gt = gt_box_list.copy() - if (not results) or (not gt): +def compute_page_iou(results_boxes: Iterable[Rectangle], ground_truth_boxes: Iterable[Rectangle]): + results = list(results_boxes) + truth = list(ground_truth_boxes) + if (not results) or (not truth): return 0 iou_sum = 0 - denominator = max(len(results), len(gt)) - while gt and results: - gt_box = gt.pop() + denominator = max(len(results), len(truth)) + while results and truth: + gt_box = truth.pop() best_match, best_iou = find_max_overlap(gt_box, results) results.remove(best_match) iou_sum += best_iou @@ -75,9 +32,30 @@ def compute_document_score(results_dict, annotation_dict): scores = [] for i in range(len(annotation_dict["pages"])): - scores.append(compute_page_iou(results_dict["pages"][i]["cells"], annotation_dict["pages"][i]["cells"])) - scores = np.array(scores) + scores.append( + compute_page_iou( + map(Rectangle.from_dict_xywh, results_dict["pages"][i]["cells"]), + map(Rectangle.from_dict_xywh, annotation_dict["pages"][i]["cells"]), + ) + ) - doc_score = np.average(scores, weights=page_weights) + doc_score = np.average(np.array(scores), weights=page_weights) return doc_score + + +""" +from cv_analysis.utils.test_metrics import * + +r1 = Rectangle.from_dict_xywh({'x': 30, 'y': 40, 'width': 50, 'height': 60}) +r2 = Rectangle.from_dict_xywh({'x': 40, 'y': 30, 'width': 55, 'height': 65}) +r3 = Rectangle.from_dict_xywh({'x': 45, 'y': 35, 'width': 45, 'height': 55}) +r4 = Rectangle.from_dict_xywh({'x': 25, 'y': 45, 'width': 45, 'height': 55}) +d1 = {"pages": [{"cells": [r1.json_xywh(), r2.json_xywh()]}]} +d2 = {"pages": [{"cells": [r3.json_xywh(), r4.json_xywh()]}]} + +compute_iou_from_boxes(r1, r2) +find_max_overlap(r1, [r2, r3, r4]) +compute_page_iou([r1, r2], [r3, r4]) +compute_document_score(d1, d2) +"""