diff --git a/fb_detr/predictor.py b/fb_detr/predictor.py index 8055120..26ab476 100644 --- a/fb_detr/predictor.py +++ b/fb_detr/predictor.py @@ -9,6 +9,7 @@ from detr.models import build_model from detr.test import get_args_parser, infer from iteration_utilities import starfilter +from fb_detr.utils.box_merging import predictions_to_lpboxes from fb_detr.utils.config import read_config @@ -62,31 +63,38 @@ class Predictor: else: return classes.tolist() - def __format_prediction(self, output: dict): + @staticmethod + def __format_probas(probas): + return probas.max( axis=1).tolist() - boxes, classes = itemgetter("bboxes", "classes")(output) + def __format_prediction(self, predictions: dict): + + boxes, classes, probas = itemgetter("bboxes", "classes", "probas")(predictions) if len(boxes): boxes = self.__format_boxes(boxes) classes = self.__format_classes(classes) + probas = self.__format_probas(probas) else: - boxes, classes = [], [] + boxes, classes, probas = [], [], [] - output["bboxes"] = boxes - output["classes"] = classes + predictions["bboxes"] = boxes + predictions["classes"] = classes + predictions["probas"] = probas - return output + return predictions def __filter_predictions_for_image(self, predictions): - boxes, classes = itemgetter("bboxes", "classes")(predictions) + boxes, classes, probas = itemgetter("bboxes", "classes", "probas")(predictions) if boxes: keep = map(lambda c: c != self.rejection_class, classes) - compressed = list(compress(zip(boxes, classes), keep)) - boxes, classes = map(list, zip(*compressed)) if compressed else ([], []) + compressed = list(compress(zip(boxes, classes, probas), keep)) + boxes, classes, probas = map(list, zip(*compressed)) if compressed else ([], []) predictions["bboxes"] = boxes predictions["classes"] = classes + predictions["probas"] = probas return predictions @@ -106,16 +114,22 @@ class Predictor: def format_predictions(self, outputs: Iterable): return map(self.__format_prediction, outputs) - def predict(self, images, threshold=None, format_output=False): + def __merge_boxes(self, predictions): + predictions = predictions_to_lpboxes(predictions) + return predictions + + def predict(self, images, threshold=None): if not threshold: threshold = read_config("threshold") predictions = infer(images, self.model, read_config("device"), threshold) + predictions = self.format_predictions(predictions) + if self.rejection_class: + predictions = self.filter_predictions(predictions) - if format_output: - predictions = self.format_predictions(predictions) - if self.rejection_class: - predictions = self.filter_predictions(predictions) + predictions = self.__merge_boxes(predictions) + + predictions = list(predictions) return predictions diff --git a/fb_detr/utils/box_merging.py b/fb_detr/utils/box_merging.py new file mode 100644 index 0000000..54c86d2 --- /dev/null +++ b/fb_detr/utils/box_merging.py @@ -0,0 +1,161 @@ +from collections import namedtuple +from itertools import starmap, combinations +from operator import attrgetter, itemgetter, truth +from frozendict import frozendict + +import numpy as np + +Rectangle = namedtuple('Rectangle', 'xmin ymin xmax ymax') + +def compute_intersection(a, b): # returns None if rectangles don't intersect + + a = Rectangle(*a.values()) + b = Rectangle(*b.values()) + + dx = min(a.xmax, b.xmax) - max(a.xmin, b.xmin) + dy = min(a.ymax, b.ymax) - max(a.ymin, b.ymin) + + intrs = dx*dy if (dx>=0) and (dy>=0) else 0 + print("intrs", intrs) + return intrs + + +def compute_union(a, b): + def area(box): + r = Rectangle(*box.values()) + return (r.xmax - r.xmin) * (r.ymax - r.ymin) + + return area(a) + area(b) + + +def compute_iou(a, b): + return compute_intersection(a, b) / compute_union(a, b) + + +LPBox = namedtuple('LPBox', 'label proba box') + + +# def filter_contained(boxes, probas, iou_thresh=.9): +# +# def make_box_proba_pair(box, proba): +# return BoxProba(box.cpu().detach(), proba) +# +# current_boxes = set(starmap(make_box_proba_pair, zip(boxes, probas))) +# print(current_boxes) +# +# +# while True: +# print(len(current_boxes)) +# remaining_boxes = set() +# for ap, bp in combinations(current_boxes, r=2): +# a = ap.box +# b = bp.box +# if iou(a, b) > iou_thresh: +# remaining_boxes.add(ap) +# else: +# remaining_boxes |= {ap, bp} +# +# if len(remaining_boxes) == len(current_boxes): +# break +# else: +# current_boxes = remaining_boxes.copy() +# +# return current_boxes + + +# def filter_boxes(image, outputs, threshold=0.3): +# # keep only predictions with confidence >= threshold +# probas = outputs.logits.softmax(-1)[0, :, :-1] +# keep = probas.max(-1).values > threshold +# +# +# boxes = outputs.pred_boxes[0, keep].cpu() +# probas = probas[keep] +# +# filtered_boxes = filter_contained(boxes, probas) +# +# boxes = list(map(attrgetter("box"), filtered_boxes)) +# probas = list(map(attrgetter("proba"), filtered_boxes)) +# +# return boxes, probas + + +def keep(a, b, iou_thresh): + + iou = compute_iou(a.box, b.box) + print("iou", iou) + if iou > iou_thresh: + max_proba_box_idx = np.array(list(map(attrgetter("proba"), [a, b]))).argmax() + print("one") + return [a, b][max_proba_box_idx], None + else: + print("both") + return a, b + + +def filter_contained(lpboxes, iou_thresh=.1): + + current_boxes = {*lpboxes} + + while True: + print("current_boxes", len(current_boxes)) + remaining = set() + for a, b in combinations(current_boxes, r=2): + print() + for keeping in filter(truth, keep(a, b, iou_thresh=iou_thresh)): + remaining.add(keeping) + + print("remaining", len(remaining)) + if len(remaining) == len(current_boxes): + break + + current_boxes = {*remaining} + + return remaining + + +def lpboxes_to_dict(lpboxes): + + boxes = map(dict, map(attrgetter("box"), lpboxes)) + classes = map(attrgetter("label"), lpboxes) + probas = map(attrgetter("proba"), lpboxes) + + boxes, classes, probas = map(list, [boxes, classes, probas]) + + return { + "boxes": boxes, + "classes": classes, + "probas": probas + } + +def page_predictions_to_lpboxes(predictions): + boxes, classes, probas = itemgetter("bboxes", "classes", "probas")(predictions) + boxes = map(frozendict, boxes) + lpboxes = list(starmap(LPBox, zip(classes, probas, boxes))) + lpboxes = filter_contained(lpboxes) + merged_predictions = lpboxes_to_dict(lpboxes) + predictions.update(merged_predictions) + return predictions + + +def predictions_to_lpboxes(predictions_per_page): + return map(page_predictions_to_lpboxes, predictions_per_page) + + + + + + + + + + + + + + + + + + + diff --git a/src/run_service.py b/src/run_service.py index be9ebcb..25f3720 100644 --- a/src/run_service.py +++ b/src/run_service.py @@ -47,7 +47,7 @@ def main(args): pdf = request.data pages = pdf2image.convert_from_bytes(pdf) - predictions = predictor.predict(pages, format_output=True) + predictions = predictor.predict(pages) return jsonify(list(predictions))