57 lines
2.2 KiB
Python
57 lines
2.2 KiB
Python
from functools import reduce
|
|
from operator import itemgetter
|
|
from typing import Iterable
|
|
|
|
import numpy as np
|
|
from funcy import lmap, lpluck, first
|
|
|
|
from cv_analysis.utils import lift
|
|
from cv_analysis.utils.rectangle import Rectangle
|
|
|
|
|
|
def compute_document_score(result_dict, ground_truth_dicts):
    """Score a predicted document against its ground truth, page by page.

    Both arguments are indexed as ``arg["pages"]``, and every page as
    ``page["cells"]``; each cell is converted with ``rectangle_from_dict``.
    Note that, despite its plural name, ``ground_truth_dicts`` is accessed
    like a single mapping with a ``"pages"`` entry.

    Each page is scored with ``compute_page_iou(result_page, truth_page)``
    and the page scores are averaged, weighted by the number of ground-truth
    cells on the respective page.

    NOTE(review): when the ground truth holds no cells at all the weights
    are computed as 0 / 0, so the result is NaN rather than a number.

    Args:
        result_dict: Predicted document.
        ground_truth_dicts: Ground-truth document.

    Returns:
        The weighted average of the per-page scores.
    """

    def cells_by_page(document):
        return [page["cells"] for page in document["pages"]]

    # Ground truth is extracted first, then the prediction.
    truth_cells = cells_by_page(ground_truth_dicts)
    result_cells = cells_by_page(result_dict)

    # lift is applied twice: once for the cells of a page, once for the pages.
    page_to_rectangles = lift(rectangle_from_dict)
    pages_to_rectangles = lift(page_to_rectangles)

    truth_rectangles = pages_to_rectangles(truth_cells)
    result_rectangles = pages_to_rectangles(result_cells)

    page_scores = [
        compute_page_iou(predicted, expected)
        for predicted, expected in zip(result_rectangles, truth_rectangles)
    ]

    truth_cell_counts = np.array([len(cells) for cells in truth_cells])
    page_weights = truth_cell_counts / truth_cell_counts.sum()

    return np.average(page_scores, weights=page_weights)
|
|
|
|
|
|
def rectangle_from_dict(d):
    """Build a ``Rectangle`` from a mapping with position and size entries.

    Args:
        d: Mapping providing the keys ``"x"``, ``"y"``, ``"width"`` and
            ``"height"``.

    Returns:
        ``Rectangle(x, y, x + width, y + height)`` -- presumably the two
        opposite corners of the box; confirm against ``Rectangle``.

    Raises:
        KeyError: If one of the four keys is missing from a dict ``d``.
    """
    x = d["x"]
    y = d["y"]
    width = d["width"]
    height = d["height"]
    return Rectangle(x, y, x + width, y + height)
|
|
|
|
|
|
def compute_page_iou(predicted_rectangles: Iterable[Rectangle], true_rectangles: Iterable[Rectangle]):
    """Score the predicted rectangles of one page against the true ones.

    Matching is greedy: every true rectangle, in set-iteration order, takes
    the remaining predicted rectangle with the highest IoU (as reported by
    ``find_max_overlap``), and that prediction is then removed from the pool.
    The collected IoU values are summed and divided by the larger of the two
    rectangle counts, so both missing and surplus predictions lower the score.

    Both inputs are turned into sets, so rectangles must be hashable and
    duplicates collapse.

    Args:
        predicted_rectangles: Rectangles produced by the system under test.
        true_rectangles: Ground-truth rectangles.

    Returns:
        The normalised IoU sum. A page with neither predicted nor true
        rectangles scores 1.0 (nothing to find, nothing reported).
    """

    def find_best_iou(sum_so_far_and_candidate_rectangles, true_rectangle):
        sum_so_far, candidates = sum_so_far_and_candidate_rectangles
        if not candidates:
            # All predictions are used up: this true rectangle stays unmatched
            # and adds nothing (previously max() on the empty set raised).
            return sum_so_far, candidates
        best_match, best_iou = find_max_overlap(true_rectangle, candidates)
        return sum_so_far + best_iou, candidates - {best_match}

    predicted_rectangles = set(predicted_rectangles)
    true_rectangles = set(true_rectangles)

    n_rectangles = max(len(predicted_rectangles), len(true_rectangles))
    if n_rectangles == 0:
        # Avoid 1 / 0; an empty prediction for an empty page is a perfect match.
        return 1.0

    # The fold never mutates predicted_rectangles; it carries shrinking copies.
    iou_sum, _ = reduce(find_best_iou, true_rectangles, (0, predicted_rectangles))
    normalizing_factor = 1 / n_rectangles
    score = normalizing_factor * iou_sum

    return score
|
|
|
|
|
|
def find_max_overlap(rectangle: Rectangle, candidate_rectangles: Iterable[Rectangle]):
    """Pick the candidate that ``rectangle.iou`` rates highest.

    Args:
        rectangle: Rectangle whose ``iou`` method is used to rate candidates.
        candidate_rectangles: Rectangles to choose from.

    Returns:
        A ``(best_candidate, iou)`` tuple, where ``iou`` is
        ``rectangle.iou(best_candidate)``. On ties the first candidate
        encountered wins.

    Raises:
        ValueError: If ``candidate_rectangles`` is empty.
    """
    iou_with = rectangle.iou
    closest = max(candidate_rectangles, key=iou_with)
    return closest, iou_with(closest)
|