diff --git a/cv_analysis/utils/test_metrics.py b/cv_analysis/utils/test_metrics.py index f9aab22..f6d62c2 100644 --- a/cv_analysis/utils/test_metrics.py +++ b/cv_analysis/utils/test_metrics.py @@ -8,28 +8,6 @@ from cv_analysis.utils import lift from cv_analysis.utils.rectangle import Rectangle -def find_max_overlap(box: Rectangle, box_list: Iterable[Rectangle]): - best_candidate = max(box_list, key=lambda x: box.iou(x)) - iou = box.iou(best_candidate) - return best_candidate, iou - - -def compute_page_iou(results_boxes: Iterable[Rectangle], ground_truth_boxes: Iterable[Rectangle]): - results = list(results_boxes) - truth = list(ground_truth_boxes) - if (not results) or (not truth): - return 0 - iou_sum = 0 - denominator = max(len(results), len(truth)) - while results and truth: - gt_box = truth.pop() - best_match, best_iou = find_max_overlap(gt_box, results) - results.remove(best_match) - iou_sum += best_iou - score = iou_sum / denominator - return score - - def compute_document_score(result_dict, ground_truth_dicts): extract_cells = lambda dicts: lpluck("cells", dicts["pages"]) @@ -55,3 +33,25 @@ def rectangle_from_dict(d): x2 = x1 + w y2 = y1 + h return Rectangle(x1, y1, x2, y2) + + +def compute_page_iou(results_boxes: Iterable[Rectangle], ground_truth_boxes: Iterable[Rectangle]): + results = list(results_boxes) + truth = list(ground_truth_boxes) + if (not results) or (not truth): + return 0 + iou_sum = 0 + denominator = max(len(results), len(truth)) + while results and truth: + gt_box = truth.pop() + best_match, best_iou = find_max_overlap(gt_box, results) + results.remove(best_match) + iou_sum += best_iou + score = iou_sum / denominator + return score + + +def find_max_overlap(box: Rectangle, box_list: Iterable[Rectangle]): + best_candidate = max(box_list, key=lambda x: box.iou(x)) + iou = box.iou(best_candidate) + return best_candidate, iou