from functools import reduce
from operator import itemgetter
from typing import Iterable

import numpy as np
from funcy import lmap, lpluck, first

from cv_analysis.utils import lift
from cv_analysis.utils.rectangle import Rectangle


def compute_document_score(result_dict, ground_truth_dicts):
    """Score a document by averaging its per-page IoU scores.

    Both arguments are mappings with a ``"pages"`` entry; every page is a
    mapping with a ``"cells"`` entry, and every cell carries ``"x"``, ``"y"``,
    ``"width"`` and ``"height"``. Pages of the result and of the ground truth
    are paired positionally.

    Args:
        result_dict: Predicted cells for the document.
        ground_truth_dicts: Annotated (true) cells for the document.

    Returns:
        The average of the page scores from ``compute_page_iou``, each page
        weighted by its share of the ground-truth cells.
    """

    def extract_cells(dicts):
        return lpluck("cells", dicts["pages"])

    cells_per_ground_truth_page, cells_per_result_page = map(extract_cells, (ground_truth_dicts, result_dict))

    # NOTE(review): `lift` is presumed to turn a function on elements into a
    # function on iterables of elements, so lifting twice maps over pages and
    # then over the cells of each page - confirm against cv_analysis.utils.
    cells_on_page_to_rectangles = lift(rectangle_from_dict)
    cells_on_pages_to_rectangles = lift(cells_on_page_to_rectangles)

    rectangles_per_ground_truth_page, rectangles_per_result_page = map(
        cells_on_pages_to_rectangles, (cells_per_ground_truth_page, cells_per_result_page)
    )

    scores = lmap(compute_page_iou, rectangles_per_result_page, rectangles_per_ground_truth_page)

    # Pages with more ground-truth cells weigh more in the document score.
    # NOTE(review): a ground truth without any cells makes this normalisation
    # divide by zero - confirm whether that input can occur.
    n_cells_per_page = np.array(lmap(len, cells_per_ground_truth_page))
    document_score = np.average(scores, weights=n_cells_per_page / n_cells_per_page.sum())

    return document_score


def rectangle_from_dict(d):
    """Build a ``Rectangle`` from a cell mapping.

    Args:
        d: Mapping with the keys ``"x"``, ``"y"``, ``"width"`` and ``"height"``.

    Returns:
        ``Rectangle(x, y, x + width, y + height)``.
    """
    x1, y1, w, h = itemgetter("x", "y", "width", "height")(d)
    return Rectangle(x1, y1, x1 + w, y1 + h)


def compute_page_iou(predicted_rectangles: Iterable[Rectangle], true_rectangles: Iterable[Rectangle]):
    """Score the predicted rectangles of one page against the true ones.

    Both inputs are de-duplicated by converting them to sets. Matching is
    greedy: each true rectangle in turn takes the remaining predicted
    rectangle that maximises ``true_rectangle.iou`` and that prediction is
    then removed from the pool, so the pairing is not necessarily the
    globally best assignment.

    Args:
        predicted_rectangles: Rectangles found on the page.
        true_rectangles: Annotated rectangles of the page.

    Returns:
        The sum of the matched IoU values divided by the larger of the two
        rectangle counts, so missing and surplus predictions both lower the
        score. If both inputs are empty the score is 1.0.
    """

    def find_best_iou(sum_so_far_and_candidate_rectangles, true_rectangle):
        sum_so_far, candidate_rectangles = sum_so_far_and_candidate_rectangles
        if not candidate_rectangles:
            # Fewer predictions than true rectangles: every prediction is
            # already matched, so this true rectangle contributes nothing.
            # Without this guard `max` would be called on an empty set.
            return sum_so_far, candidate_rectangles
        best_match, best_iou = find_max_overlap(true_rectangle, candidate_rectangles)
        return sum_so_far + best_iou, candidate_rectangles - {best_match}

    predicted_rectangles = set(predicted_rectangles)
    true_rectangles = set(true_rectangles)

    n_rectangles = max(len(predicted_rectangles), len(true_rectangles))
    if n_rectangles == 0:
        # Nothing to find and nothing predicted: treated as perfect agreement
        # rather than dividing by zero.
        return 1.0

    iou_sum = first(reduce(find_best_iou, true_rectangles, (0, predicted_rectangles)))

    normalizing_factor = 1 / n_rectangles
    score = normalizing_factor * iou_sum

    return score


def find_max_overlap(rectangle: Rectangle, candidate_rectangles: Iterable[Rectangle]):
    """Pick the candidate for which ``rectangle.iou`` is largest.

    Args:
        rectangle: Rectangle to compare against.
        candidate_rectangles: Non-empty iterable of candidates.

    Returns:
        A ``(best_candidate_rectangle, iou)`` pair.

    Raises:
        ValueError: If ``candidate_rectangles`` is empty (raised by ``max``).
    """
    best_candidate_rectangle = max(candidate_rectangles, key=rectangle.iou)
    iou = rectangle.iou(best_candidate_rectangle)
    return best_candidate_rectangle, iou