From 8ebbe0e6a7d92d4ed1fc1314595e48917027de68 Mon Sep 17 00:00:00 2001 From: Matthias Bisping Date: Fri, 4 Feb 2022 17:45:55 +0100 Subject: [PATCH] Pull request #2: Non max supprs Merge in RR/fb_detr_prediction_container from non_max_supprs to master Squashed commit of the following: commit 9cc31b70e39412b3613a117228554608d947dbb5 Author: Matthias Bisping Date: Fri Feb 4 17:41:00 2022 +0100 refactoring, renaming commit ebc37299df598b71f7569d8e8473bdb66bbbbd1a Author: Matthias Bisping Date: Fri Feb 4 17:34:26 2022 +0100 renaming commit d694866e1e98e6129f37eaf4c1950b962fed437f Author: Matthias Bisping Date: Fri Feb 4 17:33:07 2022 +0100 applied black commit 381fe2dbf5d88f008d87bd807b84174376c5bcfe Author: Matthias Bisping Date: Fri Feb 4 17:32:22 2022 +0100 duplicate detection removal completed commit ef2bab300322da3b12326d470f1c41263779e4a0 Author: Julius Unverfehrt Date: Fri Feb 4 09:58:49 2022 +0100 box merging algo WIP commit d770e56a7f31a28dea635816cae3b7b75fed0e24 Author: Julius Unverfehrt Date: Fri Feb 4 09:37:17 2022 +0100 refactor & box dropping working but algo is faulty & drops too much WIP commit 289848871caadb4438f889b8a030f30cfb64201a Author: Matthias Bisping Date: Thu Feb 3 23:56:04 2022 +0100 non max supprs WIP commit 2f1ec100b2d33409e9178af8d53218b57d9bb0e2 Author: Matthias Bisping Date: Thu Feb 3 13:32:22 2022 +0100 changed Flask to not listen on public IP --- fb_detr/predictor.py | 42 ++++++++++----- fb_detr/utils/non_max_supprs.py | 96 +++++++++++++++++++++++++++++++++ incl/detr | 2 +- requirements.txt | 1 + scripts/client_mock.py | 24 ++++----- src/run_service.py | 4 +- 6 files changed, 140 insertions(+), 29 deletions(-) create mode 100644 fb_detr/utils/non_max_supprs.py diff --git a/fb_detr/predictor.py b/fb_detr/predictor.py index 8055120..b7def19 100644 --- a/fb_detr/predictor.py +++ b/fb_detr/predictor.py @@ -9,6 +9,7 @@ from detr.models import build_model from detr.test import get_args_parser, infer from iteration_utilities import starfilter 
+from fb_detr.utils.non_max_supprs import greedy_non_max_supprs from fb_detr.utils.config import read_config @@ -62,31 +63,38 @@ class Predictor: else: return classes.tolist() - def __format_prediction(self, output: dict): + @staticmethod + def __format_probas(probas): + return probas.max(axis=1).tolist() - boxes, classes = itemgetter("bboxes", "classes")(output) + def __format_prediction(self, predictions: dict): + + boxes, classes, probas = itemgetter("bboxes", "classes", "probas")(predictions) if len(boxes): boxes = self.__format_boxes(boxes) classes = self.__format_classes(classes) + probas = self.__format_probas(probas) else: - boxes, classes = [], [] + boxes, classes, probas = [], [], [] - output["bboxes"] = boxes - output["classes"] = classes + predictions["bboxes"] = boxes + predictions["classes"] = classes + predictions["probas"] = probas - return output + return predictions def __filter_predictions_for_image(self, predictions): - boxes, classes = itemgetter("bboxes", "classes")(predictions) + boxes, classes, probas = itemgetter("bboxes", "classes", "probas")(predictions) if boxes: keep = map(lambda c: c != self.rejection_class, classes) - compressed = list(compress(zip(boxes, classes), keep)) - boxes, classes = map(list, zip(*compressed)) if compressed else ([], []) + compressed = list(compress(zip(boxes, classes, probas), keep)) + boxes, classes, probas = map(list, zip(*compressed)) if compressed else ([], [], []) predictions["bboxes"] = boxes predictions["classes"] = classes + predictions["probas"] = probas return predictions @@ -106,16 +114,22 @@ class Predictor: def format_predictions(self, outputs: Iterable): return map(self.__format_prediction, outputs) - def predict(self, images, threshold=None, format_output=False): + def __merge_boxes(self, predictions): + predictions = map(greedy_non_max_supprs, predictions) + return predictions + + def predict(self, images, threshold=None): if not threshold: threshold = read_config("threshold") predictions = 
infer(images, self.model, read_config("device"), threshold) + predictions = self.format_predictions(predictions) + if self.rejection_class: + predictions = self.filter_predictions(predictions) - if format_output: - predictions = self.format_predictions(predictions) - if self.rejection_class: - predictions = self.filter_predictions(predictions) + predictions = self.__merge_boxes(predictions) + + predictions = list(predictions) return predictions diff --git a/fb_detr/utils/non_max_supprs.py b/fb_detr/utils/non_max_supprs.py new file mode 100644 index 0000000..55811c7 --- /dev/null +++ b/fb_detr/utils/non_max_supprs.py @@ -0,0 +1,96 @@ +from collections import namedtuple +from itertools import starmap, combinations +from operator import attrgetter, itemgetter + +from frozendict import frozendict + +Rectangle = namedtuple("Rectangle", "xmin ymin xmax ymax") + + +def make_box(x1, y1, x2, y2): + keys = "x1", "y1", "x2", "y2" + return dict(zip(keys, [x1, y1, x2, y2])) + + +def compute_intersection(a, b): # returns 0 if rectangles don't intersect + + a = Rectangle(*a.values()) + b = Rectangle(*b.values()) + + dx = min(a.xmax, b.xmax) - max(a.xmin, b.xmin) + dy = min(a.ymax, b.ymax) - max(a.ymin, b.ymin) + + return dx * dy if (dx >= 0) and (dy >= 0) else 0 + + +def compute_union(a, b): + def area(box): + r = Rectangle(*box.values()) + return (r.xmax - r.xmin) * (r.ymax - r.ymin) + + return (area(a) + area(b)) - compute_intersection(a, b) + + +def compute_iou(a, b): + return compute_intersection(a, b) / compute_union(a, b) + + +LPBox = namedtuple("LPBox", "label proba box") + + +def less_likely(a, b): + return min([a, b], key=attrgetter("proba")) + + +def overlap_too_much(a, b, iou_thresh): + iou = compute_iou(a.box, b.box) + return iou > iou_thresh + + +def __greedy_non_max_supprs(lpboxes, iou_thresh=0.1): + def remove_less_likely(a, b): + try: + ll = less_likely(a, b) + current_boxes.remove(ll) + except KeyError: + pass + + current_boxes = {*lpboxes} + + while True: + n 
= len(current_boxes) + for a, b in combinations(current_boxes, r=2): + if len({a, b} & current_boxes) != 2: + continue + if overlap_too_much(a, b, iou_thresh): + remove_less_likely(a, b) + + if n == len(current_boxes): + break + + return current_boxes + + +def lpboxes_to_dict(lpboxes): + + boxes = map(dict, map(attrgetter("box"), lpboxes)) + classes = map(attrgetter("label"), lpboxes) + probas = map(attrgetter("proba"), lpboxes) + + boxes, classes, probas = map(list, [boxes, classes, probas]) + + return {"bboxes": boxes, "classes": classes, "probas": probas} + + +def greedy_non_max_supprs(predictions): + + boxes, classes, probas = itemgetter("bboxes", "classes", "probas")(predictions) + boxes = map(frozendict, boxes) + lpboxes = list(starmap(LPBox, zip(classes, probas, boxes))) + + lpboxes = __greedy_non_max_supprs(lpboxes) + + merged_predictions = lpboxes_to_dict(lpboxes) + predictions.update(merged_predictions) + + return predictions diff --git a/incl/detr b/incl/detr index 7e3258c..c17cddd 160000 --- a/incl/detr +++ b/incl/detr @@ -1 +1 @@ -Subproject commit 7e3258ccc1fa2be7a9d8ab333873b79de7005809 +Subproject commit c17cddd980ae3003a2633a65744d2265228e4c71 diff --git a/requirements.txt b/requirements.txt index 7d4c102..250550b 100644 --- a/requirements.txt +++ b/requirements.txt @@ -12,3 +12,4 @@ requests==2.27.1 iteration-utilities==0.11.0 dvc==2.9.3 dvc[ssh] +frozendict==2.3.0 diff --git a/scripts/client_mock.py b/scripts/client_mock.py index e3960cf..7d26000 100644 --- a/scripts/client_mock.py +++ b/scripts/client_mock.py @@ -4,20 +4,23 @@ from operator import itemgetter import pdf2image import requests -from PIL import ImageDraw +from PIL import ImageDraw, ImageFont -def draw_coco_box(draw: ImageDraw.Draw, bbox, klass): +def draw_coco_box(draw: ImageDraw.Draw, bbox, klass, proba): x1, y1, x2, y2 = itemgetter("x1", "y1", "x2", "y2")(bbox) draw.rectangle(((x1, y1), (x2, y2)), outline="red") - draw.text((x1, y1), text=klass, fill=(0, 0, 0, 100)) + + fnt = 
ImageFont.truetype("Pillow/Tests/fonts/FreeMono.ttf", 30) + + draw.text((x1, y2), text=f"{klass}: {proba:.2f}", fill=(0, 0, 0, 100), font=fnt) -def draw_coco_boxes(image, bboxes, classes): +def draw_coco_boxes(image, bboxes, classes, probas): draw = ImageDraw.Draw(image) - for bbox, klass in zip(bboxes, classes): - draw_coco_box(draw, bbox, klass) + for bbox, klass, proba in zip(bboxes, classes, probas): + draw_coco_box(draw, bbox, klass, proba) return image @@ -26,9 +29,9 @@ def annotate(pdf_path, predictions): pages = pdf2image.convert_from_path(pdf_path) for prd in predictions: - page_idx, boxes, classes = itemgetter("page_idx", "bboxes", "classes")(prd) + page_idx, boxes, classes, probas = itemgetter("page_idx", "bboxes", "classes", "probas")(prd) page = pages[page_idx] - image = draw_coco_boxes(page, boxes, classes) + image = draw_coco_boxes(page, boxes, classes, probas) image.save(f"/tmp/serv_out/{page_idx}.png") @@ -42,14 +45,11 @@ def parse_args(): def main(args): - response = requests.post("http://0.0.0.0:8080", data=open(args.pdf_path, "rb")) - + response = requests.post("http://127.0.0.1:5000", data=open(args.pdf_path, "rb")) response.raise_for_status() - predictions = response.json() print(json.dumps(predictions, indent=2)) - annotate(args.pdf_path, predictions) diff --git a/src/run_service.py b/src/run_service.py index 58528b2..25f3720 100644 --- a/src/run_service.py +++ b/src/run_service.py @@ -47,7 +47,7 @@ def main(args): pdf = request.data pages = pdf2image.convert_from_bytes(pdf) - predictions = predictor.predict(pages, format_output=True) + predictions = predictor.predict(pages) return jsonify(list(predictions)) @@ -58,7 +58,7 @@ def main(args): predictor = initialize_predictor() - app.run(host="0.0.0.0", port=8080) + app.run(host="127.0.0.1", port=5000) if __name__ == "__main__":