diff --git a/fb_detr/utils/box_merging.py b/fb_detr/utils/box_merging.py index 2b7ce76..29f6327 100644 --- a/fb_detr/utils/box_merging.py +++ b/fb_detr/utils/box_merging.py @@ -1,12 +1,17 @@ from collections import namedtuple from itertools import starmap, combinations -from operator import attrgetter, itemgetter, truth +from operator import attrgetter, itemgetter + from frozendict import frozendict -import numpy as np - Rectangle = namedtuple('Rectangle', 'xmin ymin xmax ymax') + +def make_box(x1, y1, x2, y2): + keys = "x1", "y1", "x2", "y2" + return dict(zip(keys, [x1, y1, x2, y2])) + + def compute_intersection(a, b): # returns None if rectangles don't intersect a = Rectangle(*a.values()) @@ -15,9 +20,7 @@ def compute_intersection(a, b): # returns None if rectangles don't intersect dx = min(a.xmax, b.xmax) - max(a.xmin, b.xmin) dy = min(a.ymax, b.ymax) - max(a.ymin, b.ymin) - intrs = dx*dy if (dx>=0) and (dy>=0) else 0 - print("intrs", intrs) - return intrs + return dx*dy if (dx>=0) and (dy>=0) else 0 def compute_union(a, b): @@ -25,7 +28,7 @@ def compute_union(a, b): r = Rectangle(*box.values()) return (r.xmax - r.xmin) * (r.ymax - r.ymin) - return area(a) + area(b) + return (area(a) + area(b)) - compute_intersection(a, b) def compute_iou(a, b): @@ -35,88 +38,38 @@ def compute_iou(a, b): LPBox = namedtuple('LPBox', 'label proba box') -# def filter_contained(boxes, probas, iou_thresh=.9): -# -# def make_box_proba_pair(box, proba): -# return BoxProba(box.cpu().detach(), proba) -# -# current_boxes = set(starmap(make_box_proba_pair, zip(boxes, probas))) -# print(current_boxes) -# -# -# while True: -# print(len(current_boxes)) -# remaining_boxes = set() -# for ap, bp in combinations(current_boxes, r=2): -# a = ap.box -# b = bp.box -# if iou(a, b) > iou_thresh: -# remaining_boxes.add(ap) -# else: -# remaining_boxes |= {ap, bp} -# -# if len(remaining_boxes) == len(current_boxes): -# break -# else: -# current_boxes = remaining_boxes.copy() -# -# return current_boxes +def less_likely(a, b): + return min([a, b], key=attrgetter("proba")) -# def filter_boxes(image, outputs, threshold=0.3): -# # keep only predictions with confidence >= threshold -# probas = outputs.logits.softmax(-1)[0, :, :-1] -# keep = probas.max(-1).values > threshold -# -# -# boxes = outputs.pred_boxes[0, keep].cpu() -# probas = probas[keep] -# -# filtered_boxes = filter_contained(boxes, probas) -# -# boxes = list(map(attrgetter("box"), filtered_boxes)) -# probas = list(map(attrgetter("proba"), filtered_boxes)) -# -# return boxes, probas - - -def remove(a, b, iou_thresh): - +def overlap_too_much(a, b, iou_thresh): iou = compute_iou(a.box, b.box) - print("iou", iou) - if iou > iou_thresh: - max_proba_box_idx = np.array(list(map(attrgetter("proba"), [a, b]))).argmax() - print("one") - return [a, b][max_proba_box_idx], None - else: - print("both") - return None, None + return iou > iou_thresh def filter_contained(lpboxes, iou_thresh=.1): + def remove_less_likely(a, b): + try: + ll = less_likely(a, b) + current_boxes.remove(ll) + except KeyError: + pass + current_boxes = {*lpboxes} - remaining = set() while True: - print() - print("current_boxes", len(current_boxes)) + n = len(current_boxes) for a, b in combinations(current_boxes, r=2): - for keeping in filter(truth, remove(a, b, iou_thresh=iou_thresh)): - remaining.add(keeping) - try: - current_boxes.remove(keeping) - except: - pass + if len({a, b} & current_boxes) != 2: + continue + if overlap_too_much(a, b, iou_thresh): + remove_less_likely(a, b) - print("remaining", len(remaining)) - if len(remaining) == len(current_boxes): + if n == len(current_boxes): break - current_boxes = {*remaining} - remaining = set() - - return remaining + return current_boxes def lpboxes_to_dict(lpboxes): @@ -128,11 +81,12 @@ def lpboxes_to_dict(lpboxes): boxes, classes, probas = map(list, [boxes, classes, probas]) return { - "boxes": boxes, + "bboxes": boxes, "classes": classes, "probas": probas } + def page_predictions_to_lpboxes(predictions): boxes, classes, probas = itemgetter("bboxes", "classes", "probas")(predictions) boxes = map(frozendict, boxes) @@ -140,6 +94,7 @@ def page_predictions_to_lpboxes(predictions): lpboxes = filter_contained(lpboxes) merged_predictions = lpboxes_to_dict(lpboxes) predictions.update(merged_predictions) + return predictions diff --git a/scripts/client_mock.py b/scripts/client_mock.py index f80506f..7d26000 100644 --- a/scripts/client_mock.py +++ b/scripts/client_mock.py @@ -4,20 +4,23 @@ from operator import itemgetter import pdf2image import requests -from PIL import ImageDraw +from PIL import ImageDraw, ImageFont -def draw_coco_box(draw: ImageDraw.Draw, bbox, klass): +def draw_coco_box(draw: ImageDraw.Draw, bbox, klass, proba): x1, y1, x2, y2 = itemgetter("x1", "y1", "x2", "y2")(bbox) draw.rectangle(((x1, y1), (x2, y2)), outline="red") - draw.text((x1, y1), text=klass, fill=(0, 0, 0, 100)) + + fnt = ImageFont.truetype("Pillow/Tests/fonts/FreeMono.ttf", 30) + + draw.text((x1, y2), text=f"{klass}: {proba:.2f}", fill=(0, 0, 0, 100), font=fnt) -def draw_coco_boxes(image, bboxes, classes): +def draw_coco_boxes(image, bboxes, classes, probas): draw = ImageDraw.Draw(image) - for bbox, klass in zip(bboxes, classes): - draw_coco_box(draw, bbox, klass) + for bbox, klass, proba in zip(bboxes, classes, probas): + draw_coco_box(draw, bbox, klass, proba) return image @@ -26,9 +29,9 @@ def annotate(pdf_path, predictions): pages = pdf2image.convert_from_path(pdf_path) for prd in predictions: - page_idx, boxes, classes = itemgetter("page_idx", "bboxes", "classes")(prd) + page_idx, boxes, classes, probas = itemgetter("page_idx", "bboxes", "classes", "probas")(prd) page = pages[page_idx] - image = draw_coco_boxes(page, boxes, classes) + image = draw_coco_boxes(page, boxes, classes, probas) image.save(f"/tmp/serv_out/{page_idx}.png")