diff --git a/fb_detr/predictor.py b/fb_detr/predictor.py index a217278..b7def19 100644 --- a/fb_detr/predictor.py +++ b/fb_detr/predictor.py @@ -9,7 +9,7 @@ from detr.models import build_model from detr.test import get_args_parser, infer from iteration_utilities import starfilter -from fb_detr.utils.box_merging import predictions_to_lpboxes +from fb_detr.utils.non_max_supprs import greedy_non_max_supprs from fb_detr.utils.config import read_config @@ -115,7 +115,7 @@ class Predictor: return map(self.__format_prediction, outputs) def __merge_boxes(self, predictions): - predictions = predictions_to_lpboxes(predictions) + predictions = map(greedy_non_max_supprs, predictions) return predictions def predict(self, images, threshold=None): diff --git a/fb_detr/utils/non_max_supprs.py b/fb_detr/utils/non_max_supprs.py index 1b94d11..55811c7 100644 --- a/fb_detr/utils/non_max_supprs.py +++ b/fb_detr/utils/non_max_supprs.py @@ -47,7 +47,7 @@ def overlap_too_much(a, b, iou_thresh): return iou > iou_thresh -def filter_contained(lpboxes, iou_thresh=0.1): +def __greedy_non_max_supprs(lpboxes, iou_thresh=0.1): def remove_less_likely(a, b): try: ll = less_likely(a, b) @@ -82,16 +82,15 @@ def lpboxes_to_dict(lpboxes): return {"bboxes": boxes, "classes": classes, "probas": probas} -def page_predictions_to_lpboxes(predictions): +def greedy_non_max_supprs(predictions): + boxes, classes, probas = itemgetter("bboxes", "classes", "probas")(predictions) boxes = map(frozendict, boxes) lpboxes = list(starmap(LPBox, zip(classes, probas, boxes))) - lpboxes = filter_contained(lpboxes) + + lpboxes = __greedy_non_max_supprs(lpboxes) + merged_predictions = lpboxes_to_dict(lpboxes) predictions.update(merged_predictions) return predictions - - -def predictions_to_lpboxes(predictions_per_page): - return map(page_predictions_to_lpboxes, predictions_per_page)