Refactor metrics

This commit is contained in:
Matthias Bisping 2023-01-09 16:22:52 +01:00
parent 65e9735bd9
commit a97f8def7c

View File

@ -1,8 +1,9 @@
from functools import reduce
from operator import itemgetter
from typing import Iterable
import numpy as np
from funcy import lmap, lpluck
from funcy import lmap, lpluck, first
from cv_analysis.utils import lift
from cv_analysis.utils.rectangle import Rectangle
@ -30,25 +31,26 @@ def compute_document_score(result_dict, ground_truth_dicts):
def rectangle_from_dict(d):
x1, y1, w, h = itemgetter("x", "y", "width", "height")(d)
x2 = x1 + w
y2 = y1 + h
return Rectangle(x1, y1, x2, y2)
return Rectangle(x1, y1, x1 + w, y1 + h)
def compute_page_iou(results_boxes: Iterable[Rectangle], ground_truth_boxes: Iterable[Rectangle]):
results = list(results_boxes)
truth = list(ground_truth_boxes)
def compute_page_iou(predicted_rectangles: Iterable[Rectangle], true_rectangles: Iterable[Rectangle]):
def find_best_iou(sum_so_far_and_candidate_rectangles, true_rectangle):
sum_so_far, predicted_rectangles = sum_so_far_and_candidate_rectangles
best_match, best_iou = find_max_overlap(true_rectangle, predicted_rectangles)
return sum_so_far + best_iou, predicted_rectangles - {best_match}
def find_best_iou(gt_box):
best_match, best_iou = find_max_overlap(gt_box, results)
results.remove(best_match)
return best_iou
predicted_rectangles = set(predicted_rectangles)
true_rectangles = set(true_rectangles)
iou_sum = first(reduce(find_best_iou, true_rectangles, (0, predicted_rectangles)))
normalizing_factor = 1 / max(len(predicted_rectangles), len(true_rectangles))
score = normalizing_factor * iou_sum
score = sum(map(find_best_iou, truth)) / max(len(results), len(truth))
return score
def find_max_overlap(box: Rectangle, box_list: Iterable[Rectangle]):
best_candidate = max(box_list, key=lambda x: box.iou(x))
iou = box.iou(best_candidate)
return best_candidate, iou
def find_max_overlap(rectangle: Rectangle, candidate_rectangles: Iterable[Rectangle]):
best_candidate_rectangle = max(candidate_rectangles, key=rectangle.iou)
iou = rectangle.iou(best_candidate_rectangle)
return best_candidate_rectangle, iou