diff --git a/fb_detr/predictor.py b/fb_detr/predictor.py index 26ab476..a217278 100644 --- a/fb_detr/predictor.py +++ b/fb_detr/predictor.py @@ -65,7 +65,7 @@ class Predictor: @staticmethod def __format_probas(probas): - return probas.max( axis=1).tolist() + return probas.max(axis=1).tolist() def __format_prediction(self, predictions: dict): @@ -91,7 +91,7 @@ class Predictor: if boxes: keep = map(lambda c: c != self.rejection_class, classes) compressed = list(compress(zip(boxes, classes, probas), keep)) - boxes, classes, probas = map(list, zip(*compressed)) if compressed else ([], []) + boxes, classes, probas = map(list, zip(*compressed)) if compressed else ([], [], []) predictions["bboxes"] = boxes predictions["classes"] = classes predictions["probas"] = probas diff --git a/fb_detr/utils/box_merging.py b/fb_detr/utils/box_merging.py index 54c86d2..7a10d3c 100644 --- a/fb_detr/utils/box_merging.py +++ b/fb_detr/utils/box_merging.py @@ -96,20 +96,26 @@ def keep(a, b, iou_thresh): def filter_contained(lpboxes, iou_thresh=.1): current_boxes = {*lpboxes} + remaining = set() while True: + print() print("current_boxes", len(current_boxes)) - remaining = set() for a, b in combinations(current_boxes, r=2): - print() for keeping in filter(truth, keep(a, b, iou_thresh=iou_thresh)): remaining.add(keeping) + try: + current_boxes.remove(keeping) + except: + pass + break print("remaining", len(remaining)) if len(remaining) == len(current_boxes): break current_boxes = {*remaining} + remaining = set() return remaining diff --git a/incl/detr b/incl/detr index 7e3258c..c17cddd 160000 --- a/incl/detr +++ b/incl/detr @@ -1 +1 @@ -Subproject commit 7e3258ccc1fa2be7a9d8ab333873b79de7005809 +Subproject commit c17cddd980ae3003a2633a65744d2265228e4c71 diff --git a/requirements.txt b/requirements.txt index 7d4c102..250550b 100644 --- a/requirements.txt +++ b/requirements.txt @@ -12,3 +12,4 @@ requests==2.27.1 iteration-utilities==0.11.0 dvc==2.9.3 dvc[ssh] +frozendict==2.3.0