From 2f1ec100b2d33409e9178af8d53218b57d9bb0e2 Mon Sep 17 00:00:00 2001 From: Matthias Bisping Date: Thu, 3 Feb 2022 13:32:22 +0100 Subject: [PATCH 1/9] changed Flask to not listen on public IP --- scripts/client_mock.py | 5 +---- src/run_service.py | 2 +- 2 files changed, 2 insertions(+), 5 deletions(-) diff --git a/scripts/client_mock.py b/scripts/client_mock.py index e3960cf..f80506f 100644 --- a/scripts/client_mock.py +++ b/scripts/client_mock.py @@ -42,14 +42,11 @@ def parse_args(): def main(args): - response = requests.post("http://0.0.0.0:8080", data=open(args.pdf_path, "rb")) - + response = requests.post("http://127.0.0.1:5000", data=open(args.pdf_path, "rb")) response.raise_for_status() - predictions = response.json() print(json.dumps(predictions, indent=2)) - annotate(args.pdf_path, predictions) diff --git a/src/run_service.py b/src/run_service.py index 58528b2..be9ebcb 100644 --- a/src/run_service.py +++ b/src/run_service.py @@ -58,7 +58,7 @@ def main(args): predictor = initialize_predictor() - app.run(host="0.0.0.0", port=8080) + app.run(host="127.0.0.1", port=5000) if __name__ == "__main__": From 289848871caadb4438f889b8a030f30cfb64201a Mon Sep 17 00:00:00 2001 From: Matthias Bisping Date: Thu, 3 Feb 2022 23:56:04 +0100 Subject: [PATCH 2/9] non max supprs WIP --- fb_detr/predictor.py | 42 ++++++--- fb_detr/utils/box_merging.py | 161 +++++++++++++++++++++++++++++++++++ src/run_service.py | 2 +- 3 files changed, 190 insertions(+), 15 deletions(-) create mode 100644 fb_detr/utils/box_merging.py diff --git a/fb_detr/predictor.py b/fb_detr/predictor.py index 8055120..26ab476 100644 --- a/fb_detr/predictor.py +++ b/fb_detr/predictor.py @@ -9,6 +9,7 @@ from detr.models import build_model from detr.test import get_args_parser, infer from iteration_utilities import starfilter +from fb_detr.utils.box_merging import predictions_to_lpboxes from fb_detr.utils.config import read_config @@ -62,31 +63,38 @@ class Predictor: else: return classes.tolist() - def __format_prediction(self, output: dict): + @staticmethod + def __format_probas(probas): + return probas.max( axis=1).tolist() - boxes, classes = itemgetter("bboxes", "classes")(output) + def __format_prediction(self, predictions: dict): + + boxes, classes, probas = itemgetter("bboxes", "classes", "probas")(predictions) if len(boxes): boxes = self.__format_boxes(boxes) classes = self.__format_classes(classes) + probas = self.__format_probas(probas) else: - boxes, classes = [], [] + boxes, classes, probas = [], [], [] - output["bboxes"] = boxes - output["classes"] = classes + predictions["bboxes"] = boxes + predictions["classes"] = classes + predictions["probas"] = probas - return output + return predictions def __filter_predictions_for_image(self, predictions): - boxes, classes = itemgetter("bboxes", "classes")(predictions) + boxes, classes, probas = itemgetter("bboxes", "classes", "probas")(predictions) if boxes: keep = map(lambda c: c != self.rejection_class, classes) - compressed = list(compress(zip(boxes, classes), keep)) - boxes, classes = map(list, zip(*compressed)) if compressed else ([], []) + compressed = list(compress(zip(boxes, classes, probas), keep)) + boxes, classes, probas = map(list, zip(*compressed)) if compressed else ([], []) predictions["bboxes"] = boxes predictions["classes"] = classes + predictions["probas"] = probas return predictions @@ -106,16 +114,22 @@ class Predictor: def format_predictions(self, outputs: Iterable): return map(self.__format_prediction, outputs) - def predict(self, images, threshold=None, format_output=False): + def __merge_boxes(self, predictions): + predictions = predictions_to_lpboxes(predictions) + return predictions + + def predict(self, images, threshold=None): if not threshold: threshold = read_config("threshold") predictions = infer(images, self.model, read_config("device"), threshold) + predictions = self.format_predictions(predictions) + if self.rejection_class: + predictions = self.filter_predictions(predictions) - if format_output: - predictions = self.format_predictions(predictions) - if self.rejection_class: - predictions = self.filter_predictions(predictions) + predictions = self.__merge_boxes(predictions) + + predictions = list(predictions) return predictions diff --git a/fb_detr/utils/box_merging.py b/fb_detr/utils/box_merging.py new file mode 100644 index 0000000..54c86d2 --- /dev/null +++ b/fb_detr/utils/box_merging.py @@ -0,0 +1,161 @@ +from collections import namedtuple +from itertools import starmap, combinations +from operator import attrgetter, itemgetter, truth +from frozendict import frozendict + +import numpy as np + +Rectangle = namedtuple('Rectangle', 'xmin ymin xmax ymax') + +def compute_intersection(a, b): # returns None if rectangles don't intersect + + a = Rectangle(*a.values()) + b = Rectangle(*b.values()) + + dx = min(a.xmax, b.xmax) - max(a.xmin, b.xmin) + dy = min(a.ymax, b.ymax) - max(a.ymin, b.ymin) + + intrs = dx*dy if (dx>=0) and (dy>=0) else 0 + print("intrs", intrs) + return intrs + + +def compute_union(a, b): + def area(box): + r = Rectangle(*box.values()) + return (r.xmax - r.xmin) * (r.ymax - r.ymin) + + return area(a) + area(b) + + +def compute_iou(a, b): + return compute_intersection(a, b) / compute_union(a, b) + + +LPBox = namedtuple('LPBox', 'label proba box') + + +# def filter_contained(boxes, probas, iou_thresh=.9): +# +# def make_box_proba_pair(box, proba): +# return BoxProba(box.cpu().detach(), proba) +# +# current_boxes = set(starmap(make_box_proba_pair, zip(boxes, probas))) +# print(current_boxes) +# +# +# while True: +# print(len(current_boxes)) +# remaining_boxes = set() +# for ap, bp in combinations(current_boxes, r=2): +# a = ap.box +# b = bp.box +# if iou(a, b) > iou_thresh: +# remaining_boxes.add(ap) +# else: +# remaining_boxes |= {ap, bp} +# +# if len(remaining_boxes) == len(current_boxes): +# break +# else: +# current_boxes = remaining_boxes.copy() +# +# return current_boxes + + +# def filter_boxes(image, outputs, threshold=0.3): +# # keep only predictions with confidence >= threshold +# probas = outputs.logits.softmax(-1)[0, :, :-1] +# keep = probas.max(-1).values > threshold +# +# +# boxes = outputs.pred_boxes[0, keep].cpu() +# probas = probas[keep] +# +# filtered_boxes = filter_contained(boxes, probas) +# +# boxes = list(map(attrgetter("box"), filtered_boxes)) +# probas = list(map(attrgetter("proba"), filtered_boxes)) +# +# return boxes, probas + + +def keep(a, b, iou_thresh): + + iou = compute_iou(a.box, b.box) + print("iou", iou) + if iou > iou_thresh: + max_proba_box_idx = np.array(list(map(attrgetter("proba"), [a, b]))).argmax() + print("one") + return [a, b][max_proba_box_idx], None + else: + print("both") + return a, b + + +def filter_contained(lpboxes, iou_thresh=.1): + + current_boxes = {*lpboxes} + + while True: + print("current_boxes", len(current_boxes)) + remaining = set() + for a, b in combinations(current_boxes, r=2): + print() + for keeping in filter(truth, keep(a, b, iou_thresh=iou_thresh)): + remaining.add(keeping) + + print("remaining", len(remaining)) + if len(remaining) == len(current_boxes): + break + + current_boxes = {*remaining} + + return remaining + + +def lpboxes_to_dict(lpboxes): + + boxes = map(dict, map(attrgetter("box"), lpboxes)) + classes = map(attrgetter("label"), lpboxes) + probas = map(attrgetter("proba"), lpboxes) + + boxes, classes, probas = map(list, [boxes, classes, probas]) + + return { + "boxes": boxes, + "classes": classes, + "probas": probas + } + +def page_predictions_to_lpboxes(predictions): + boxes, classes, probas = itemgetter("bboxes", "classes", "probas")(predictions) + boxes = map(frozendict, boxes) + lpboxes = list(starmap(LPBox, zip(classes, probas, boxes))) + lpboxes = filter_contained(lpboxes) + merged_predictions = lpboxes_to_dict(lpboxes) + predictions.update(merged_predictions) + return predictions + + +def predictions_to_lpboxes(predictions_per_page): + return map(page_predictions_to_lpboxes, predictions_per_page) + + + + + + + + + + + + + + + + + + + diff --git a/src/run_service.py b/src/run_service.py index be9ebcb..25f3720 100644 --- a/src/run_service.py +++ b/src/run_service.py @@ -47,7 +47,7 @@ def main(args): pdf = request.data pages = pdf2image.convert_from_bytes(pdf) - predictions = predictor.predict(pages, format_output=True) + predictions = predictor.predict(pages) return jsonify(list(predictions)) From d770e56a7f31a28dea635816cae3b7b75fed0e24 Mon Sep 17 00:00:00 2001 From: Julius Unverfehrt Date: Fri, 4 Feb 2022 09:37:17 +0100 Subject: [PATCH 3/9] refactor & box dropping working but algo is faulty & drops too much WIP --- fb_detr/predictor.py | 4 ++-- fb_detr/utils/box_merging.py | 10 ++++++++-- incl/detr | 2 +- requirements.txt | 1 + 4 files changed, 12 insertions(+), 5 deletions(-) diff --git a/fb_detr/predictor.py b/fb_detr/predictor.py index 26ab476..a217278 100644 --- a/fb_detr/predictor.py +++ b/fb_detr/predictor.py @@ -65,7 +65,7 @@ class Predictor: @staticmethod def __format_probas(probas): - return probas.max( axis=1).tolist() + return probas.max(axis=1).tolist() def __format_prediction(self, predictions: dict): @@ -91,7 +91,7 @@ class Predictor: if boxes: keep = map(lambda c: c != self.rejection_class, classes) compressed = list(compress(zip(boxes, classes, probas), keep)) - boxes, classes, probas = map(list, zip(*compressed)) if compressed else ([], []) + boxes, classes, probas = map(list, zip(*compressed)) if compressed else ([], [], []) predictions["bboxes"] = boxes predictions["classes"] = classes predictions["probas"] = probas diff --git a/fb_detr/utils/box_merging.py b/fb_detr/utils/box_merging.py index 54c86d2..7a10d3c 100644 --- a/fb_detr/utils/box_merging.py +++ b/fb_detr/utils/box_merging.py @@ -96,20 +96,26 @@ def keep(a, b, iou_thresh): def filter_contained(lpboxes, iou_thresh=.1): current_boxes = {*lpboxes} + remaining = set() while True: + print() print("current_boxes", len(current_boxes)) - remaining = set() for a, b in combinations(current_boxes, r=2): - print() for keeping in filter(truth, keep(a, b, iou_thresh=iou_thresh)): remaining.add(keeping) + try: + current_boxes.remove(keeping) + except: + pass + break print("remaining", len(remaining)) if len(remaining) == len(current_boxes): break current_boxes = {*remaining} + remaining = set() return remaining diff --git a/incl/detr b/incl/detr index 7e3258c..c17cddd 160000 --- a/incl/detr +++ b/incl/detr @@ -1 +1 @@ -Subproject commit 7e3258ccc1fa2be7a9d8ab333873b79de7005809 +Subproject commit c17cddd980ae3003a2633a65744d2265228e4c71 diff --git a/requirements.txt b/requirements.txt index 7d4c102..250550b 100644 --- a/requirements.txt +++ b/requirements.txt @@ -12,3 +12,4 @@ requests==2.27.1 iteration-utilities==0.11.0 dvc==2.9.3 dvc[ssh] +frozendict==2.3.0 From ef2bab300322da3b12326d470f1c41263779e4a0 Mon Sep 17 00:00:00 2001 From: Julius Unverfehrt Date: Fri, 4 Feb 2022 09:58:49 +0100 Subject: [PATCH 4/9] box merging algo WIP --- fb_detr/utils/box_merging.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/fb_detr/utils/box_merging.py b/fb_detr/utils/box_merging.py index 7a10d3c..2b7ce76 100644 --- a/fb_detr/utils/box_merging.py +++ b/fb_detr/utils/box_merging.py @@ -80,7 +80,7 @@ LPBox = namedtuple('LPBox', 'label proba box') # return boxes, probas -def keep(a, b, iou_thresh): +def remove(a, b, iou_thresh): iou = compute_iou(a.box, b.box) print("iou", iou) @@ -90,7 +90,7 @@ def keep(a, b, iou_thresh): return [a, b][max_proba_box_idx], None else: print("both") - return a, b + return None, None def filter_contained(lpboxes, iou_thresh=.1): @@ -102,13 +102,12 @@ def filter_contained(lpboxes, iou_thresh=.1): print() print("current_boxes", len(current_boxes)) for a, b in combinations(current_boxes, r=2): - for keeping in filter(truth, keep(a, b, iou_thresh=iou_thresh)): + for keeping in filter(truth, remove(a, b, iou_thresh=iou_thresh)): remaining.add(keeping) try: current_boxes.remove(keeping) except: pass - break print("remaining", len(remaining)) if len(remaining) == len(current_boxes): From 381fe2dbf5d88f008d87bd807b84174376c5bcfe Mon Sep 17 00:00:00 2001 From: Matthias Bisping Date: Fri, 4 Feb 2022 17:32:22 +0100 Subject: [PATCH 5/9] duplicate detection removal completed --- fb_detr/utils/box_merging.py | 107 ++++++++++------------------------- scripts/client_mock.py | 19 ++++--- 2 files changed, 42 insertions(+), 84 deletions(-) diff --git a/fb_detr/utils/box_merging.py b/fb_detr/utils/box_merging.py index 2b7ce76..29f6327 100644 --- a/fb_detr/utils/box_merging.py +++ b/fb_detr/utils/box_merging.py @@ -1,12 +1,17 @@ from collections import namedtuple from itertools import starmap, combinations -from operator import attrgetter, itemgetter, truth +from operator import attrgetter, itemgetter + from frozendict import frozendict -import numpy as np - Rectangle = namedtuple('Rectangle', 'xmin ymin xmax ymax') + +def make_box(x1, y1, x2, y2): + keys = "x1", "y1", "x2", "y2" + return dict(zip(keys, [x1, y1, x2, y2])) + + def compute_intersection(a, b): # returns None if rectangles don't intersect a = Rectangle(*a.values()) @@ -15,9 +20,7 @@ def compute_intersection(a, b): # returns None if rectangles don't intersect dx = min(a.xmax, b.xmax) - max(a.xmin, b.xmin) dy = min(a.ymax, b.ymax) - max(a.ymin, b.ymin) - intrs = dx*dy if (dx>=0) and (dy>=0) else 0 - print("intrs", intrs) - return intrs + return dx*dy if (dx>=0) and (dy>=0) else 0 def compute_union(a, b): @@ -25,7 +28,7 @@ def compute_union(a, b): r = Rectangle(*box.values()) return (r.xmax - r.xmin) * (r.ymax - r.ymin) - return area(a) + area(b) + return (area(a) + area(b)) - compute_intersection(a, b) def compute_iou(a, b): @@ -35,88 +38,38 @@ def compute_iou(a, b): LPBox = namedtuple('LPBox', 'label proba box') -# def filter_contained(boxes, probas, iou_thresh=.9): -# -# def make_box_proba_pair(box, proba): -# return BoxProba(box.cpu().detach(), proba) -# -# current_boxes = set(starmap(make_box_proba_pair, zip(boxes, probas))) -# print(current_boxes) -# -# -# while True: -# print(len(current_boxes)) -# remaining_boxes = set() -# for ap, bp in combinations(current_boxes, r=2): -# a = ap.box -# b = bp.box -# if iou(a, b) > iou_thresh: -# remaining_boxes.add(ap) -# else: -# remaining_boxes |= {ap, bp} -# -# if len(remaining_boxes) == len(current_boxes): -# break -# else: -# current_boxes = remaining_boxes.copy() -# -# return current_boxes +def less_likely(a, b): + return min([a, b], key=attrgetter("proba")) -# def filter_boxes(image, outputs, threshold=0.3): -# # keep only predictions with confidence >= threshold -# probas = outputs.logits.softmax(-1)[0, :, :-1] -# keep = probas.max(-1).values > threshold -# -# -# boxes = outputs.pred_boxes[0, keep].cpu() -# probas = probas[keep] -# -# filtered_boxes = filter_contained(boxes, probas) -# -# boxes = list(map(attrgetter("box"), filtered_boxes)) -# probas = list(map(attrgetter("proba"), filtered_boxes)) -# -# return boxes, probas - - -def remove(a, b, iou_thresh): - +def overlap_too_much(a, b, iou_thresh): iou = compute_iou(a.box, b.box) - print("iou", iou) - if iou > iou_thresh: - max_proba_box_idx = np.array(list(map(attrgetter("proba"), [a, b]))).argmax() - print("one") - return [a, b][max_proba_box_idx], None - else: - print("both") - return None, None + return iou > iou_thresh def filter_contained(lpboxes, iou_thresh=.1): + def remove_less_likely(a, b): + try: + ll = less_likely(a, b) + current_boxes.remove(ll) + except KeyError: + pass + current_boxes = {*lpboxes} - remaining = set() while True: - print() - print("current_boxes", len(current_boxes)) + n = len(current_boxes) for a, b in combinations(current_boxes, r=2): - for keeping in filter(truth, remove(a, b, iou_thresh=iou_thresh)): - remaining.add(keeping) - try: - current_boxes.remove(keeping) - except: - pass + if len({a, b} & current_boxes) != 2: + continue + if overlap_too_much(a, b, iou_thresh): + remove_less_likely(a, b) - print("remaining", len(remaining)) - if len(remaining) == len(current_boxes): + if n == len(current_boxes): break - current_boxes = {*remaining} - remaining = set() - - return remaining + return current_boxes def lpboxes_to_dict(lpboxes): @@ -128,11 +81,12 @@ def lpboxes_to_dict(lpboxes): boxes, classes, probas = map(list, [boxes, classes, probas]) return { - "boxes": boxes, + "bboxes": boxes, "classes": classes, "probas": probas } + def page_predictions_to_lpboxes(predictions): boxes, classes, probas = itemgetter("bboxes", "classes", "probas")(predictions) boxes = map(frozendict, boxes) @@ -140,6 +94,7 @@ def page_predictions_to_lpboxes(predictions): lpboxes = filter_contained(lpboxes) merged_predictions = lpboxes_to_dict(lpboxes) predictions.update(merged_predictions) + return predictions diff --git a/scripts/client_mock.py b/scripts/client_mock.py index f80506f..7d26000 100644 --- a/scripts/client_mock.py +++ b/scripts/client_mock.py @@ -4,20 +4,23 @@ from operator import itemgetter import pdf2image import requests -from PIL import ImageDraw +from PIL import ImageDraw, ImageFont -def draw_coco_box(draw: ImageDraw.Draw, bbox, klass): +def draw_coco_box(draw: ImageDraw.Draw, bbox, klass, proba): x1, y1, x2, y2 = itemgetter("x1", "y1", "x2", "y2")(bbox) draw.rectangle(((x1, y1), (x2, y2)), outline="red") - draw.text((x1, y1), text=klass, fill=(0, 0, 0, 100)) + + fnt = ImageFont.truetype("Pillow/Tests/fonts/FreeMono.ttf", 30) + + draw.text((x1, y2), text=f"{klass}: {proba:.2f}", fill=(0, 0, 0, 100), font=fnt) -def draw_coco_boxes(image, bboxes, classes): +def draw_coco_boxes(image, bboxes, classes, probas): draw = ImageDraw.Draw(image) - for bbox, klass in zip(bboxes, classes): - draw_coco_box(draw, bbox, klass) + for bbox, klass, proba in zip(bboxes, classes, probas): + draw_coco_box(draw, bbox, klass, proba) return image @@ -26,9 +29,9 @@ def annotate(pdf_path, predictions): pages = pdf2image.convert_from_path(pdf_path) for prd in predictions: - page_idx, boxes, classes = itemgetter("page_idx", "bboxes", "classes")(prd) + page_idx, boxes, classes, probas = itemgetter("page_idx", "bboxes", "classes", "probas")(prd) page = pages[page_idx] - image = draw_coco_boxes(page, boxes, classes) + image = draw_coco_boxes(page, boxes, classes, probas) image.save(f"/tmp/serv_out/{page_idx}.png") From d694866e1e98e6129f37eaf4c1950b962fed437f Mon Sep 17 00:00:00 2001 From: Matthias Bisping Date: Fri, 4 Feb 2022 17:33:07 +0100 Subject: [PATCH 6/9] applied black --- fb_detr/utils/box_merging.py | 34 +++++----------------------------- 1 file changed, 5 insertions(+), 29 deletions(-) diff --git a/fb_detr/utils/box_merging.py b/fb_detr/utils/box_merging.py index 29f6327..1b94d11 100644 --- a/fb_detr/utils/box_merging.py +++ b/fb_detr/utils/box_merging.py @@ -4,7 +4,7 @@ from operator import attrgetter, itemgetter from frozendict import frozendict -Rectangle = namedtuple('Rectangle', 'xmin ymin xmax ymax') +Rectangle = namedtuple("Rectangle", "xmin ymin xmax ymax") def make_box(x1, y1, x2, y2): @@ -20,7 +20,7 @@ def compute_intersection(a, b): # returns None if rectangles don't intersect dx = min(a.xmax, b.xmax) - max(a.xmin, b.xmin) dy = min(a.ymax, b.ymax) - max(a.ymin, b.ymin) - return dx*dy if (dx>=0) and (dy>=0) else 0 + return dx * dy if (dx >= 0) and (dy >= 0) else 0 def compute_union(a, b): @@ -35,7 +35,7 @@ def compute_iou(a, b): return compute_intersection(a, b) / compute_union(a, b) -LPBox = namedtuple('LPBox', 'label proba box') +LPBox = namedtuple("LPBox", "label proba box") def less_likely(a, b): @@ -47,8 +47,7 @@ def overlap_too_much(a, b, iou_thresh): return iou > iou_thresh -def filter_contained(lpboxes, iou_thresh=.1): - +def filter_contained(lpboxes, iou_thresh=0.1): def remove_less_likely(a, b): try: ll = less_likely(a, b) @@ -80,11 +79,7 @@ def lpboxes_to_dict(lpboxes): boxes, classes, probas = map(list, [boxes, classes, probas]) - return { - "bboxes": boxes, - "classes": classes, - "probas": probas - } + return {"bboxes": boxes, "classes": classes, "probas": probas} def page_predictions_to_lpboxes(predictions): @@ -100,22 +95,3 @@ def page_predictions_to_lpboxes(predictions): def predictions_to_lpboxes(predictions_per_page): return map(page_predictions_to_lpboxes, predictions_per_page) - - - - - - - - - - - - - - - - - - - From ebc37299df598b71f7569d8e8473bdb66bbbbd1a Mon Sep 17 00:00:00 2001 From: Matthias Bisping Date: Fri, 4 Feb 2022 17:34:26 +0100 Subject: [PATCH 7/9] renaming --- fb_detr/utils/{box_merging.py => non_max_supprs.py} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename fb_detr/utils/{box_merging.py => non_max_supprs.py} (100%) diff --git a/fb_detr/utils/box_merging.py b/fb_detr/utils/non_max_supprs.py similarity index 100% rename from fb_detr/utils/box_merging.py rename to fb_detr/utils/non_max_supprs.py From 9cc31b70e39412b3613a117228554608d947dbb5 Mon Sep 17 00:00:00 2001 From: Matthias Bisping Date: Fri, 4 Feb 2022 17:41:00 +0100 Subject: [PATCH 8/9] refactoring, renaming --- fb_detr/predictor.py | 4 ++-- fb_detr/utils/non_max_supprs.py | 13 ++++++------- 2 files changed, 8 insertions(+), 9 deletions(-) diff --git a/fb_detr/predictor.py b/fb_detr/predictor.py index a217278..b7def19 100644 --- a/fb_detr/predictor.py +++ b/fb_detr/predictor.py @@ -9,7 +9,7 @@ from detr.models import build_model from detr.test import get_args_parser, infer from iteration_utilities import starfilter -from fb_detr.utils.box_merging import predictions_to_lpboxes +from fb_detr.utils.non_max_supprs import greedy_non_max_supprs from fb_detr.utils.config import read_config @@ -115,7 +115,7 @@ class Predictor: return map(self.__format_prediction, outputs) def __merge_boxes(self, predictions): - predictions = predictions_to_lpboxes(predictions) + predictions = map(greedy_non_max_supprs, predictions) return predictions def predict(self, images, threshold=None): diff --git a/fb_detr/utils/non_max_supprs.py b/fb_detr/utils/non_max_supprs.py index 1b94d11..55811c7 100644 --- a/fb_detr/utils/non_max_supprs.py +++ b/fb_detr/utils/non_max_supprs.py @@ -47,7 +47,7 @@ def overlap_too_much(a, b, iou_thresh): return iou > iou_thresh -def filter_contained(lpboxes, iou_thresh=0.1): +def __greedy_non_max_supprs(lpboxes, iou_thresh=0.1): def remove_less_likely(a, b): try: ll = less_likely(a, b) @@ -82,16 +82,15 @@ def lpboxes_to_dict(lpboxes): return {"bboxes": boxes, "classes": classes, "probas": probas} -def page_predictions_to_lpboxes(predictions): +def greedy_non_max_supprs(predictions): + boxes, classes, probas = itemgetter("bboxes", "classes", "probas")(predictions) boxes = map(frozendict, boxes) lpboxes = list(starmap(LPBox, zip(classes, probas, boxes))) - lpboxes = filter_contained(lpboxes) + + lpboxes = __greedy_non_max_supprs(lpboxes) + merged_predictions = lpboxes_to_dict(lpboxes) predictions.update(merged_predictions) return predictions - - -def predictions_to_lpboxes(predictions_per_page): - return map(page_predictions_to_lpboxes, predictions_per_page) From 8b308a0906bc73c62db7cc1e63413b44d5e2558c Mon Sep 17 00:00:00 2001 From: Matthias Bisping Date: Fri, 4 Feb 2022 17:46:17 +0100 Subject: [PATCH 9/9] removed outdated comment --- fb_detr/utils/non_max_supprs.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/fb_detr/utils/non_max_supprs.py b/fb_detr/utils/non_max_supprs.py index 55811c7..f38a63e 100644 --- a/fb_detr/utils/non_max_supprs.py +++ b/fb_detr/utils/non_max_supprs.py @@ -12,7 +12,7 @@ def make_box(x1, y1, x2, y2): return dict(zip(keys, [x1, y1, x2, y2])) -def compute_intersection(a, b): # returns None if rectangles don't intersect +def compute_intersection(a, b): a = Rectangle(*a.values()) b = Rectangle(*b.values())