2022-03-01 14:17:37 +01:00

97 lines
2.3 KiB
Python

from collections import namedtuple
from itertools import starmap, combinations
from operator import attrgetter, itemgetter
from frozendict import frozendict
# Axis-aligned box. The helpers below never reorder coordinates, so they
# expect xmin <= xmax and ymin <= ymax.
Rectangle = namedtuple("Rectangle", "xmin ymin xmax ymax")

# Key names of a box mapping, listed in (xmin, ymin, xmax, ymax) order.
_BOX_KEYS = ("x1", "y1", "x2", "y2")


def make_box(x1, y1, x2, y2):
    """Build a box mapping with keys "x1", "y1", "x2", "y2" (in that order)."""
    return dict(zip(_BOX_KEYS, (x1, y1, x2, y2)))


def _as_rectangle(box):
    """Convert a box mapping into a Rectangle.

    Mappings that carry all of "x1", "y1", "x2", "y2" are read by key, so
    the result does not depend on key insertion order (e.g. a box that was
    JSON-deserialized with sorted keys ``x1, x2, y1, y2``). Any other mapping
    falls back to reading its values positionally as (xmin, ymin, xmax, ymax).
    """
    if all(key in box for key in _BOX_KEYS):
        return Rectangle(*(box[key] for key in _BOX_KEYS))
    return Rectangle(*box.values())


def compute_intersection(a, b):
    """Return the overlap area of box mappings *a* and *b* (0 when disjoint)."""
    a, b = _as_rectangle(a), _as_rectangle(b)
    dx = min(a.xmax, b.xmax) - max(a.xmin, b.xmin)
    dy = min(a.ymax, b.ymax) - max(a.ymin, b.ymin)
    return dx * dy if (dx >= 0) and (dy >= 0) else 0


def compute_union(a, b):
    """Return area(a) + area(b) - intersection(a, b) for two box mappings."""
    def area(box):
        r = _as_rectangle(box)
        return (r.xmax - r.xmin) * (r.ymax - r.ymin)
    return (area(a) + area(b)) - compute_intersection(a, b)


def compute_iou(a, b):
    """Return the intersection-over-union of box mappings *a* and *b*.

    Returns 0.0 when the union area is zero (e.g. two degenerate, zero-area
    boxes) instead of raising ZeroDivisionError.
    """
    union = compute_union(a, b)
    if union == 0:
        return 0.0
    return compute_intersection(a, b) / union
# One detection: its class label, its probability (used to rank overlapping
# detections against each other) and its box mapping. Instances are stored in
# sets during suppression, so the box must be hashable (callers wrap it in a
# frozendict).
LPBox = namedtuple("LPBox", "label proba box")
def less_likely(a, b):
    """Return whichever of *a* and *b* has the lower ``proba``.

    On a tie, *a* is returned.
    """
    if b.proba < a.proba:
        return b
    return a
def overlap_too_much(a, b, iou_thresh):
    """Tell whether detections *a* and *b* overlap by more than *iou_thresh*.

    Only the ``.box`` attributes are compared (labels are ignored), and the
    comparison against the threshold is strict.
    """
    return compute_iou(a.box, b.box) > iou_thresh
def __greedy_non_max_supprs(lpboxes, iou_thresh=0.1):
    """Greedy non-maximum suppression over LPBox detections.

    Candidates are visited from most to least probable; ties keep their
    input order, because ``sorted`` is stable even with ``reverse=True``.
    A candidate survives unless it overlaps an already-kept (hence at least
    as probable) detection by more than ``iou_thresh``. Suppression is
    class-agnostic, since ``overlap_too_much`` compares only the boxes.

    The traversal order is fixed, so the result is deterministic. Visiting
    pairs in set order would make it depend on hashing and allow chain
    suppression (C removed by B before B itself is removed by A).

    Args:
        lpboxes: iterable of LPBox items.
        iou_thresh: IoU above which the less probable box is dropped.

    Returns:
        A set of the surviving LPBox items.
    """
    kept = []
    for candidate in sorted(lpboxes, key=attrgetter("proba"), reverse=True):
        if not any(overlap_too_much(winner, candidate, iou_thresh) for winner in kept):
            kept.append(candidate)
    return set(kept)
def lpboxes_to_dict(lpboxes):
    """Convert LPBox detections back into parallel prediction lists.

    Args:
        lpboxes: iterable of items with ``label``, ``proba`` and ``box``
            attributes. One-shot iterators are accepted, because the input
            is materialised once before the three lists are built.

    Returns:
        ``{"bboxes": [...], "classes": [...], "probas": [...]}``. Each box is
        copied into a plain ``dict``, and all three lists share one order.
    """
    lpboxes = list(lpboxes)
    return {
        "bboxes": [dict(lp.box) for lp in lpboxes],
        "classes": [lp.label for lp in lpboxes],
        "probas": [lp.proba for lp in lpboxes],
    }
def greedy_non_max_supprs(predictions, iou_thresh=0.1):
    """Run greedy non-maximum suppression on a predictions mapping, in place.

    Args:
        predictions: mapping holding parallel sequences under "bboxes" (box
            mappings), "classes" and "probas". It is mutated in place.
        iou_thresh: IoU above which the less probable of two overlapping
            boxes is dropped. Defaults to 0.1, the value previously
            hard-coded.

    Returns:
        The same ``predictions`` object. Its "bboxes", "classes" and "probas"
        now hold only the surviving detections, in the order they first
        appeared in the input (exact duplicates collapsed).

    Note:
        ``zip`` silently truncates to the shortest of the three sequences.
    """
    boxes, classes, probas = itemgetter("bboxes", "classes", "probas")(predictions)
    # frozendict makes each box hashable so the LPBox tuples can live in a set.
    lpboxes = list(starmap(LPBox, zip(classes, probas, map(frozendict, boxes))))
    kept = __greedy_non_max_supprs(lpboxes, iou_thresh=iou_thresh)
    # The suppression step returns a set, whose iteration order follows
    # hashing and may vary between runs (e.g. randomized str hashes for
    # labels). Emit the survivors in input order instead.
    ordered = [lp for lp in dict.fromkeys(lpboxes) if lp in kept]
    predictions.update(lpboxes_to_dict(ordered))
    return predictions