"""Merge overlapping box predictions on a page via non-maximum suppression."""
import logging
from collections import namedtuple
from itertools import starmap, combinations
from operator import attrgetter, itemgetter, truth

from frozendict import frozendict
import numpy as np

logger = logging.getLogger(__name__)

# Axis-aligned rectangle. Boxes are converted positionally from their
# mapping's values, so a box mapping must yield xmin, ymin, xmax, ymax in
# that order (the keys themselves are not checked).
Rectangle = namedtuple('Rectangle', 'xmin ymin xmax ymax')


def _area(box):
    """Return the area of *box* (a mapping of xmin, ymin, xmax, ymax values)."""
    r = Rectangle(*box.values())
    return (r.xmax - r.xmin) * (r.ymax - r.ymin)


def compute_intersection(a, b):
    """Return the overlap area of boxes *a* and *b*, or 0 if they don't overlap.

    Both arguments are mappings whose values, in iteration order, are
    xmin, ymin, xmax, ymax.
    """
    a = Rectangle(*a.values())
    b = Rectangle(*b.values())
    dx = min(a.xmax, b.xmax) - max(a.xmin, b.xmin)
    dy = min(a.ymax, b.ymax) - max(a.ymin, b.ymin)
    return dx * dy if dx >= 0 and dy >= 0 else 0


def compute_union(a, b):
    """Return the area covered by *a* or *b*, counting their overlap once."""
    return _area(a) + _area(b) - compute_intersection(a, b)


def compute_iou(a, b):
    """Return the intersection-over-union of boxes *a* and *b*.

    For well-formed boxes the result lies in [0, 1]. Returns 0.0 when the
    union is empty (both boxes degenerate) instead of dividing by zero.
    """
    intersection = compute_intersection(a, b)
    # Union inlined so the intersection is computed only once.
    union = _area(a) + _area(b) - intersection
    if union <= 0:
        return 0.0
    return intersection / union


# A labelled, scored box. ``box`` must be hashable (page_predictions_to_lpboxes
# wraps it in a frozendict) so LPBox instances can be stored in sets/dicts.
LPBox = namedtuple('LPBox', 'label proba box')


def keep(a, b, iou_thresh):
    """Resolve a single pair of LPBoxes.

    Returns ``(winner, None)`` when the boxes' IoU exceeds *iou_thresh*, where
    winner is the box with the higher ``proba`` (*a* on ties); otherwise
    returns ``(a, b)`` unchanged.
    """
    if compute_iou(a.box, b.box) > iou_thresh:
        winner = (a, b)[np.array([a.proba, b.proba]).argmax()]
        return winner, None
    return a, b


def filter_contained(lpboxes, iou_thresh=.1):
    """Drop boxes that overlap a more confident box (greedy NMS).

    Boxes are visited in decreasing ``proba``. A box is kept only if its IoU
    with every box kept so far is at most *iou_thresh*. Boxes with equal
    ``proba`` are visited in input order, and duplicate LPBoxes are collapsed.

    Args:
        lpboxes: iterable of hashable LPBox.
        iou_thresh: IoU above which the less confident box is suppressed.

    Returns:
        set of the surviving LPBoxes (empty only if the input is empty).
    """
    # dict.fromkeys dedupes while keeping input order; sorted() is stable
    # (also with reverse=True), so ties break deterministically.
    candidates = sorted(dict.fromkeys(lpboxes), key=attrgetter("proba"), reverse=True)
    kept = []
    for candidate in candidates:
        if all(compute_iou(candidate.box, k.box) <= iou_thresh for k in kept):
            kept.append(candidate)
    logger.debug("filter_contained: kept %d of %d boxes", len(kept), len(candidates))
    return set(kept)


def lpboxes_to_dict(lpboxes):
    """Split LPBoxes into parallel ``boxes`` / ``classes`` / ``probas`` lists.

    Each box is converted back to a plain ``dict``. The input is materialised
    first, so one-shot iterators work too.
    """
    lpboxes = list(lpboxes)
    return {
        "boxes": [dict(lp.box) for lp in lpboxes],
        "classes": [lp.label for lp in lpboxes],
        "probas": [lp.proba for lp in lpboxes],
    }


def page_predictions_to_lpboxes(predictions):
    """Filter one page's predictions in place and return the same dict.

    Reads the parallel "bboxes", "classes" and "probas" sequences. It then
    overwrites "classes"/"probas" and sets "boxes" to the surviving
    predictions, in their original order.

    NOTE(review): the raw "bboxes" entry is left unfiltered while "boxes" holds
    the filtered boxes, so "bboxes" no longer lines up with "classes". Confirm
    which key downstream consumers read.
    """
    boxes, classes, probas = itemgetter("bboxes", "classes", "probas")(predictions)
    # frozendict makes each box hashable so LPBoxes can be deduplicated.
    lpboxes = list(starmap(LPBox, zip(classes, probas, map(frozendict, boxes))))
    survivors = filter_contained(lpboxes)
    # Keep the original prediction order (set iteration order is arbitrary).
    ordered = [lp for lp in dict.fromkeys(lpboxes) if lp in survivors]
    predictions.update(lpboxes_to_dict(ordered))
    return predictions


def predictions_to_lpboxes(predictions_per_page):
    """Lazily apply page_predictions_to_lpboxes to every page.

    Returns an iterator. Each page dict is filtered (and mutated) only when
    the iterator reaches it.
    """
    return map(page_predictions_to_lpboxes, predictions_per_page)