diff --git a/fb_detr/utils/box_merging.py b/fb_detr/utils/box_merging.py index 7a10d3c..2b7ce76 100644 --- a/fb_detr/utils/box_merging.py +++ b/fb_detr/utils/box_merging.py @@ -80,7 +80,7 @@ LPBox = namedtuple('LPBox', 'label proba box') # return boxes, probas -def keep(a, b, iou_thresh): +def remove(a, b, iou_thresh): iou = compute_iou(a.box, b.box) print("iou", iou) @@ -90,7 +90,7 @@ def keep(a, b, iou_thresh): return [a, b][max_proba_box_idx], None else: print("both") - return a, b + return None, None def filter_contained(lpboxes, iou_thresh=.1): @@ -102,13 +102,12 @@ def filter_contained(lpboxes, iou_thresh=.1): print() print("current_boxes", len(current_boxes)) for a, b in combinations(current_boxes, r=2): - for keeping in filter(truth, keep(a, b, iou_thresh=iou_thresh)): + for keeping in filter(truth, remove(a, b, iou_thresh=iou_thresh)): remaining.add(keeping) try: current_boxes.remove(keeping) except: pass - break print("remaining", len(remaining)) if len(remaining) == len(current_boxes):