Non-max suppression (WIP)
This commit is contained in:
parent
2f1ec100b2
commit
289848871c
@ -9,6 +9,7 @@ from detr.models import build_model
|
||||
from detr.test import get_args_parser, infer
|
||||
from iteration_utilities import starfilter
|
||||
|
||||
from fb_detr.utils.box_merging import predictions_to_lpboxes
|
||||
from fb_detr.utils.config import read_config
|
||||
|
||||
|
||||
@ -62,31 +63,38 @@ class Predictor:
|
||||
else:
|
||||
return classes.tolist()
|
||||
|
||||
def __format_prediction(self, output: dict):
|
||||
@staticmethod
|
||||
def __format_probas(probas):
|
||||
return probas.max( axis=1).tolist()
|
||||
|
||||
boxes, classes = itemgetter("bboxes", "classes")(output)
|
||||
def __format_prediction(self, predictions: dict):
|
||||
|
||||
boxes, classes, probas = itemgetter("bboxes", "classes", "probas")(predictions)
|
||||
|
||||
if len(boxes):
|
||||
boxes = self.__format_boxes(boxes)
|
||||
classes = self.__format_classes(classes)
|
||||
probas = self.__format_probas(probas)
|
||||
else:
|
||||
boxes, classes = [], []
|
||||
boxes, classes, probas = [], [], []
|
||||
|
||||
output["bboxes"] = boxes
|
||||
output["classes"] = classes
|
||||
predictions["bboxes"] = boxes
|
||||
predictions["classes"] = classes
|
||||
predictions["probas"] = probas
|
||||
|
||||
return output
|
||||
return predictions
|
||||
|
||||
def __filter_predictions_for_image(self, predictions):
|
||||
|
||||
boxes, classes = itemgetter("bboxes", "classes")(predictions)
|
||||
boxes, classes, probas = itemgetter("bboxes", "classes", "probas")(predictions)
|
||||
|
||||
if boxes:
|
||||
keep = map(lambda c: c != self.rejection_class, classes)
|
||||
compressed = list(compress(zip(boxes, classes), keep))
|
||||
boxes, classes = map(list, zip(*compressed)) if compressed else ([], [])
|
||||
compressed = list(compress(zip(boxes, classes, probas), keep))
|
||||
boxes, classes, probas = map(list, zip(*compressed)) if compressed else ([], [])
|
||||
predictions["bboxes"] = boxes
|
||||
predictions["classes"] = classes
|
||||
predictions["probas"] = probas
|
||||
|
||||
return predictions
|
||||
|
||||
@ -106,16 +114,22 @@ class Predictor:
|
||||
def format_predictions(self, outputs: Iterable):
|
||||
return map(self.__format_prediction, outputs)
|
||||
|
||||
def predict(self, images, threshold=None, format_output=False):
|
||||
def __merge_boxes(self, predictions):
|
||||
predictions = predictions_to_lpboxes(predictions)
|
||||
return predictions
|
||||
|
||||
def predict(self, images, threshold=None):
|
||||
|
||||
if not threshold:
|
||||
threshold = read_config("threshold")
|
||||
|
||||
predictions = infer(images, self.model, read_config("device"), threshold)
|
||||
predictions = self.format_predictions(predictions)
|
||||
if self.rejection_class:
|
||||
predictions = self.filter_predictions(predictions)
|
||||
|
||||
if format_output:
|
||||
predictions = self.format_predictions(predictions)
|
||||
if self.rejection_class:
|
||||
predictions = self.filter_predictions(predictions)
|
||||
predictions = self.__merge_boxes(predictions)
|
||||
|
||||
predictions = list(predictions)
|
||||
|
||||
return predictions
|
||||
|
||||
161
fb_detr/utils/box_merging.py
Normal file
161
fb_detr/utils/box_merging.py
Normal file
@ -0,0 +1,161 @@
|
||||
from collections import namedtuple
|
||||
from itertools import starmap, combinations
|
||||
from operator import attrgetter, itemgetter, truth
|
||||
from frozendict import frozendict
|
||||
|
||||
import numpy as np
|
||||
|
||||
Rectangle = namedtuple('Rectangle', 'xmin ymin xmax ymax')


def _as_rectangle(box):
    """Convert a box mapping into a Rectangle.

    The mapping's values are unpacked positionally, so this assumes they
    iterate in (xmin, ymin, xmax, ymax) order -- TODO confirm against the
    code that produces the boxes.
    """
    return Rectangle(*box.values())


def _area(rect):
    """Return the area of a Rectangle."""
    return (rect.xmax - rect.xmin) * (rect.ymax - rect.ymin)


def compute_intersection(a, b):
    """Return the area of the overlap between boxes ``a`` and ``b``.

    Args:
        a: mapping describing the first box (see ``_as_rectangle``).
        b: mapping describing the second box.

    Returns:
        The overlap area, or 0 when the boxes do not overlap.
    """
    a = _as_rectangle(a)
    b = _as_rectangle(b)

    dx = min(a.xmax, b.xmax) - max(a.xmin, b.xmin)
    dy = min(a.ymax, b.ymax) - max(a.ymin, b.ymin)

    # A negative extent on either axis means the boxes are disjoint.
    return dx * dy if dx >= 0 and dy >= 0 else 0


def compute_union(a, b):
    """Return the area covered by box ``a`` or box ``b``.

    The overlap is subtracted once so that it is not counted twice
    (inclusion-exclusion).
    """
    total = _area(_as_rectangle(a)) + _area(_as_rectangle(b))
    return total - compute_intersection(a, b)


def compute_iou(a, b):
    """Return the intersection-over-union of boxes ``a`` and ``b``.

    Returns:
        A value in [0, 1] for well-formed boxes; 0.0 when the union has no
        positive area (e.g. two zero-area boxes), instead of dividing by zero.
    """
    union = compute_union(a, b)
    if union <= 0:
        return 0.0
    return compute_intersection(a, b) / union
|
||||
|
||||
|
||||
LPBox = namedtuple('LPBox', 'label proba box')
|
||||
|
||||
|
||||
# def filter_contained(boxes, probas, iou_thresh=.9):
|
||||
#
|
||||
# def make_box_proba_pair(box, proba):
|
||||
# return BoxProba(box.cpu().detach(), proba)
|
||||
#
|
||||
# current_boxes = set(starmap(make_box_proba_pair, zip(boxes, probas)))
|
||||
# print(current_boxes)
|
||||
#
|
||||
#
|
||||
# while True:
|
||||
# print(len(current_boxes))
|
||||
# remaining_boxes = set()
|
||||
# for ap, bp in combinations(current_boxes, r=2):
|
||||
# a = ap.box
|
||||
# b = bp.box
|
||||
# if iou(a, b) > iou_thresh:
|
||||
# remaining_boxes.add(ap)
|
||||
# else:
|
||||
# remaining_boxes |= {ap, bp}
|
||||
#
|
||||
# if len(remaining_boxes) == len(current_boxes):
|
||||
# break
|
||||
# else:
|
||||
# current_boxes = remaining_boxes.copy()
|
||||
#
|
||||
# return current_boxes
|
||||
|
||||
|
||||
# def filter_boxes(image, outputs, threshold=0.3):
|
||||
# # keep only predictions with confidence >= threshold
|
||||
# probas = outputs.logits.softmax(-1)[0, :, :-1]
|
||||
# keep = probas.max(-1).values > threshold
|
||||
#
|
||||
#
|
||||
# boxes = outputs.pred_boxes[0, keep].cpu()
|
||||
# probas = probas[keep]
|
||||
#
|
||||
# filtered_boxes = filter_contained(boxes, probas)
|
||||
#
|
||||
# boxes = list(map(attrgetter("box"), filtered_boxes))
|
||||
# probas = list(map(attrgetter("proba"), filtered_boxes))
|
||||
#
|
||||
# return boxes, probas
|
||||
|
||||
|
||||
def keep(a, b, iou_thresh):
    """Decide which of two labelled boxes survive a pairwise overlap check.

    Args:
        a: first candidate; must expose ``.box`` and ``.proba``.
        b: second candidate; must expose ``.box`` and ``.proba``.
        iou_thresh: IoU strictly above this value counts as an overlap.

    Returns:
        ``(winner, None)`` when the boxes overlap, where ``winner`` is the
        candidate with the higher ``proba`` (``a`` on a tie); otherwise
        ``(a, b)`` unchanged.
    """
    if compute_iou(a.box, b.box) > iou_thresh:
        # max() returns the first maximal argument, so ``a`` wins ties.
        return max(a, b, key=attrgetter("proba")), None
    return a, b
|
||||
|
||||
|
||||
def filter_contained(lpboxes, iou_thresh=.1):
    """Suppress overlapping boxes, keeping the most probable one of each cluster.

    This is greedy non-maximum suppression: candidates are visited from the
    highest to the lowest ``proba`` and a candidate is kept only if its IoU
    with every already-kept box does not exceed ``iou_thresh``. Labels are
    not compared, so suppression is applied across classes.

    Args:
        lpboxes: iterable of hashable items exposing ``.box`` and ``.proba``.
        iou_thresh: IoU strictly above this value counts as an overlap.

    Returns:
        A set with the surviving items. A lone box, or the winner of an
        overlapping pair, is always retained; empty input gives an empty set.
    """
    survivors = []

    # sorted() is stable, so candidates with equal proba keep their input order.
    for candidate in sorted(lpboxes, key=attrgetter("proba"), reverse=True):
        overlaps_kept = any(
            compute_iou(candidate.box, kept.box) > iou_thresh
            for kept in survivors
        )
        if not overlaps_kept:
            survivors.append(candidate)

    # A set is returned to stay compatible with the previous return type.
    return set(survivors)
|
||||
|
||||
|
||||
def lpboxes_to_dict(lpboxes):
    """Split labelled boxes back into parallel lists.

    Args:
        lpboxes: iterable of items exposing ``.box`` (a mapping), ``.label``
            and ``.proba``. It is materialised once, so one-shot iterators
            are supported.

    Returns:
        A dict with the keys ``"bboxes"``, ``"classes"`` and ``"probas"``,
        each holding a list aligned by position. The ``"bboxes"`` key matches
        the key the prediction dicts are read with, so updating a prediction
        dict with this result replaces its boxes instead of adding a second,
        differently named entry.
    """
    items = list(lpboxes)

    return {
        # Each box is copied into a plain dict.
        "bboxes": [dict(item.box) for item in items],
        "classes": [item.label for item in items],
        "probas": [item.proba for item in items],
    }
|
||||
|
||||
def page_predictions_to_lpboxes(predictions):
    """Run overlap filtering on the predictions of a single page.

    Reads the ``"bboxes"``, ``"classes"`` and ``"probas"`` entries, wraps
    them as ``LPBox`` items (boxes frozen so the items are hashable), passes
    them through ``filter_contained`` and writes the result of
    ``lpboxes_to_dict`` back into ``predictions``.

    Args:
        predictions: dict for one page; it is updated in place.

    Returns:
        The same ``predictions`` dict that was passed in.
    """
    raw_boxes = predictions["bboxes"]
    labels = predictions["classes"]
    scores = predictions["probas"]

    candidates = [
        LPBox(label, score, frozendict(box))
        for label, score, box in zip(labels, scores, raw_boxes)
    ]
    survivors = filter_contained(candidates)

    predictions.update(lpboxes_to_dict(survivors))
    return predictions
|
||||
|
||||
|
||||
def predictions_to_lpboxes(predictions_per_page):
    """Lazily apply ``page_predictions_to_lpboxes`` to every page.

    Args:
        predictions_per_page: iterable of per-page prediction dicts.

    Returns:
        A ``map`` iterator; pages are only processed as it is consumed.
    """
    per_page = page_predictions_to_lpboxes
    return map(per_page, predictions_per_page)
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
@ -47,7 +47,7 @@ def main(args):
|
||||
pdf = request.data
|
||||
|
||||
pages = pdf2image.convert_from_bytes(pdf)
|
||||
predictions = predictor.predict(pages, format_output=True)
|
||||
predictions = predictor.predict(pages)
|
||||
|
||||
return jsonify(list(predictions))
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user