# 2022-02-03 23:56:04 +01:00
#
# 162 lines
# 3.8 KiB
# Python

from collections import namedtuple
from itertools import starmap, combinations
from operator import attrgetter, itemgetter, truth
from frozendict import frozendict
import numpy as np
# Axis-aligned box in corner form. Field order matters: rectangles are built
# positionally from the values of a box mapping.
Rectangle = namedtuple(typename="Rectangle", field_names="xmin ymin xmax ymax")
def compute_intersection(a, b):
    """Return the overlap area of two axis-aligned boxes.

    Each box is a mapping whose values, in insertion order, are
    ``xmin, ymin, xmax, ymax`` (they are unpacked positionally into
    ``Rectangle``). Returns 0 (not None) when the boxes do not overlap.
    """
    a = Rectangle(*a.values())
    b = Rectangle(*b.values())
    dx = min(a.xmax, b.xmax) - max(a.xmin, b.xmin)
    dy = min(a.ymax, b.ymax) - max(a.ymin, b.ymin)
    # A negative extent on either axis means the boxes are disjoint.
    return dx * dy if dx >= 0 and dy >= 0 else 0
def compute_union(a, b):
    """Return the area of the union of two axis-aligned boxes.

    Each box is a mapping whose values, in insertion order, are
    ``xmin, ymin, xmax, ymax``. The union is ``area(a) + area(b)`` minus the
    overlap; without subtracting the overlap it would be counted twice and
    any IoU built on this value could never exceed 0.5.
    """
    def area(box):
        r = Rectangle(*box.values())
        return (r.xmax - r.xmin) * (r.ymax - r.ymin)

    return area(a) + area(b) - compute_intersection(a, b)
def compute_iou(a, b):
    """Return intersection-over-union of two box mappings.

    Returns 0.0 when the union area is not positive (e.g. both boxes are
    degenerate), instead of raising ZeroDivisionError.
    """
    union = compute_union(a, b)
    if union <= 0:
        return 0.0
    return compute_intersection(a, b) / union
# One detection: its class ``label``, its ``proba`` and its ``box`` mapping.
LPBox = namedtuple(typename="LPBox", field_names="label proba box")
# def filter_contained(boxes, probas, iou_thresh=.9):
#
# def make_box_proba_pair(box, proba):
# return BoxProba(box.cpu().detach(), proba)
#
# current_boxes = set(starmap(make_box_proba_pair, zip(boxes, probas)))
# print(current_boxes)
#
#
# while True:
# print(len(current_boxes))
# remaining_boxes = set()
# for ap, bp in combinations(current_boxes, r=2):
# a = ap.box
# b = bp.box
# if iou(a, b) > iou_thresh:
# remaining_boxes.add(ap)
# else:
# remaining_boxes |= {ap, bp}
#
# if len(remaining_boxes) == len(current_boxes):
# break
# else:
# current_boxes = remaining_boxes.copy()
#
# return current_boxes
# def filter_boxes(image, outputs, threshold=0.3):
# # keep only predictions with confidence >= threshold
# probas = outputs.logits.softmax(-1)[0, :, :-1]
# keep = probas.max(-1).values > threshold
#
#
# boxes = outputs.pred_boxes[0, keep].cpu()
# probas = probas[keep]
#
# filtered_boxes = filter_contained(boxes, probas)
#
# boxes = list(map(attrgetter("box"), filtered_boxes))
# probas = list(map(attrgetter("proba"), filtered_boxes))
#
# return boxes, probas
def keep(a, b, iou_thresh):
    """Decide which of two LPBoxes survive a pairwise overlap check.

    If the IoU of ``a.box`` and ``b.box`` exceeds ``iou_thresh``, only the
    box with the higher ``proba`` survives and ``(winner, None)`` is
    returned. Otherwise both survive and ``(a, b)`` is returned.
    """
    if compute_iou(a.box, b.box) > iou_thresh:
        # np.argmax returns the first maximal index, so ties favour ``a``.
        winner = (a, b)[np.argmax([a.proba, b.proba])]
        return winner, None
    return a, b
def filter_contained(lpboxes, iou_thresh=.1):
    """Suppress boxes that overlap a higher-probability box.

    Boxes are visited in descending ``proba`` order (input order breaks
    ties, since ``sorted`` is stable). A box is kept only if its IoU with
    every already-kept box is not above ``iou_thresh``. A pairwise union of
    survivors cannot be used instead: it drops a lone box and re-admits a
    box suppressed in one pair through any other pair it does not overlap.

    Args:
        lpboxes: iterable of ``LPBox``; ``proba`` values must be mutually
            comparable (presumably scalar scores — confirm with callers).
        iou_thresh: IoU above which the lower-probability box is dropped.

    Returns:
        A set of the surviving ``LPBox`` items.
    """
    kept = []
    for candidate in sorted(lpboxes, key=attrgetter("proba"), reverse=True):
        if not any(compute_iou(k.box, candidate.box) > iou_thresh for k in kept):
            kept.append(candidate)
    return set(kept)
def lpboxes_to_dict(lpboxes):
    """Convert LPBoxes into parallel ``boxes`` / ``classes`` / ``probas`` lists.

    ``lpboxes`` may be any iterable, including a one-shot generator. It is
    materialized once so that the three output lists stay aligned. Each box
    mapping is copied into a plain ``dict``.
    """
    lpboxes = list(lpboxes)
    return {
        "boxes": [dict(lp.box) for lp in lpboxes],
        "classes": [lp.label for lp in lpboxes],
        "probas": [lp.proba for lp in lpboxes],
    }
def page_predictions_to_lpboxes(predictions):
    """Filter one page's predictions in place and return the same dict.

    Reads the parallel ``"bboxes"``, ``"classes"`` and ``"probas"`` entries
    and wraps each detection in an ``LPBox``. Boxes are frozen so that
    LPBoxes are hashable. The result is run through ``filter_contained`` and
    the survivors are written back into ``predictions`` via
    ``lpboxes_to_dict``.

    NOTE(review): the input boxes are read from ``"bboxes"`` but the filtered
    ones are written under ``"boxes"``, so the unfiltered ``"bboxes"`` entry
    is left in place. Confirm downstream readers use ``"boxes"``.
    """
    raw_boxes = predictions["bboxes"]
    raw_classes = predictions["classes"]
    raw_probas = predictions["probas"]
    candidates = [
        LPBox(label, proba, frozendict(box))
        for label, proba, box in zip(raw_classes, raw_probas, raw_boxes)
    ]
    survivors = filter_contained(candidates)
    predictions.update(lpboxes_to_dict(survivors))
    return predictions
def predictions_to_lpboxes(predictions_per_page):
    """Lazily apply ``page_predictions_to_lpboxes`` to every page.

    Returns an iterator. Each page is processed only when the iterator is
    consumed.
    """
    return (page_predictions_to_lpboxes(page) for page in predictions_per_page)