58 lines
1.9 KiB
Python
58 lines
1.9 KiB
Python
from operator import itemgetter
|
|
from typing import Iterable
|
|
|
|
import numpy as np
|
|
from funcy import lmap, lpluck
|
|
|
|
from cv_analysis.utils import lift
|
|
from cv_analysis.utils.rectangle import Rectangle
|
|
|
|
|
|
def find_max_overlap(box: "Rectangle", box_list: Iterable["Rectangle"]):
    """Return the candidate with the highest IoU against ``box``, and that IoU.

    Args:
        box: Reference rectangle; only its ``iou`` method is used.
        box_list: Candidates to compare against ``box``. It is traversed
            exactly once, so a one-shot iterator is acceptable.

    Returns:
        A ``(best_candidate, iou)`` tuple. When several candidates share the
        highest IoU, the one that appears first in ``box_list`` is returned.

    Raises:
        ValueError: If ``box_list`` yields no candidates.
    """
    # Score every candidate exactly once and keep the score next to it, so the
    # winner's IoU does not have to be recomputed after the search.
    scored = [(candidate, box.iou(candidate)) for candidate in box_list]
    if not scored:
        raise ValueError("find_max_overlap() requires at least one candidate box")
    # Compare on the score only; the rectangles themselves need not be orderable.
    return max(scored, key=itemgetter(1))
|
|
|
|
|
|
def compute_page_iou(results_boxes: Iterable[Rectangle], ground_truth_boxes: Iterable[Rectangle]):
    """Score one page by greedily pairing ground-truth boxes with result boxes.

    Ground-truth boxes are processed from last to first. Each one is paired
    with the still-unmatched result box that overlaps it most (as decided by
    ``find_max_overlap``), and that result box is then removed from further
    consideration. The IoUs of all pairs are summed and divided by the larger
    of the two box counts, so surplus boxes on either side lower the score.

    Args:
        results_boxes: Predicted rectangles for the page.
        ground_truth_boxes: Reference rectangles for the page.

    Returns:
        ``0`` if either input is empty, otherwise the summed pair IoU divided
        by ``max(number of results, number of ground-truth boxes)``.
    """
    # Work on private copies: the matching below removes entries as it goes.
    unmatched = list(results_boxes)
    targets = list(ground_truth_boxes)
    if not (unmatched and targets):
        return 0

    # Fix the normaliser before any result box is consumed.
    n_slots = max(len(unmatched), len(targets))

    total_iou = 0
    for target in reversed(targets):
        if not unmatched:
            # More ground-truth boxes than results: the rest contribute nothing.
            break
        match, overlap = find_max_overlap(target, unmatched)
        unmatched.remove(match)
        total_iou += overlap

    return total_iou / n_slots
|
|
|
|
|
|
def compute_document_score(result_dict, ground_truth_dicts):
    """Compute a document-level score as a weighted average of page scores.

    Both arguments are indexed with ``["pages"]``, and the ``"cells"`` entry of
    every page is extracted. Each cell is converted with
    ``rectangle_from_dict``, each page is scored with ``compute_page_iou``,
    and the page scores are averaged with weights proportional to the number
    of ground-truth cells on each page.

    Args:
        result_dict: Mapping with a ``"pages"`` entry holding the predicted
            pages.
        ground_truth_dicts: Mapping with a ``"pages"`` entry holding the
            reference pages.

    Returns:
        The value of ``np.average`` over the per-page scores.
    """

    def cells_per_page(document):
        return lpluck("cells", document["pages"])

    cells_per_ground_truth_page = cells_per_page(ground_truth_dicts)
    cells_per_result_page = cells_per_page(result_dict)

    # NOTE(review): `lift` is defined elsewhere; presumably it turns a function
    # on items into one on collections, so lifting twice handles pages of
    # cells -- confirm against cv_analysis.utils.
    pages_to_rectangles = lift(lift(rectangle_from_dict))
    ground_truth_rectangles = pages_to_rectangles(cells_per_ground_truth_page)
    result_rectangles = pages_to_rectangles(cells_per_result_page)

    page_scores = lmap(compute_page_iou, result_rectangles, ground_truth_rectangles)

    # Weight each page by its share of the document's ground-truth cells.
    cell_counts = np.array(lmap(len, cells_per_ground_truth_page))
    page_weights = cell_counts / cell_counts.sum()

    return np.average(page_scores, weights=page_weights)
|
|
|
|
|
|
def rectangle_from_dict(d):
    """Build a ``Rectangle`` from a mapping with x, y, width and height.

    Args:
        d: Object supporting ``d[key]`` for the keys ``"x"``, ``"y"``,
            ``"width"`` and ``"height"``.

    Returns:
        ``Rectangle(x, y, x + width, y + height)``, i.e. the rectangle is
        constructed from its two corner points rather than origin and size.
    """
    # Look up all four fields before doing any arithmetic.
    left, top, width, height = (d[key] for key in ("x", "y", "width", "height"))
    right = left + width
    bottom = top + height
    return Rectangle(left, top, right, bottom)
|