refactoring, renaming
This commit is contained in:
parent
ebc37299df
commit
9cc31b70e3
@ -9,7 +9,7 @@ from detr.models import build_model
|
||||
from detr.test import get_args_parser, infer
|
||||
from iteration_utilities import starfilter
|
||||
|
||||
from fb_detr.utils.box_merging import predictions_to_lpboxes
|
||||
from fb_detr.utils.non_max_supprs import greedy_non_max_supprs
|
||||
from fb_detr.utils.config import read_config
|
||||
|
||||
|
||||
@ -115,7 +115,7 @@ class Predictor:
|
||||
return map(self.__format_prediction, outputs)
|
||||
|
||||
def __merge_boxes(self, predictions):
|
||||
predictions = predictions_to_lpboxes(predictions)
|
||||
predictions = map(greedy_non_max_supprs, predictions)
|
||||
return predictions
|
||||
|
||||
def predict(self, images, threshold=None):
|
||||
|
||||
@ -47,7 +47,7 @@ def overlap_too_much(a, b, iou_thresh):
|
||||
return iou > iou_thresh
|
||||
|
||||
|
||||
def filter_contained(lpboxes, iou_thresh=0.1):
|
||||
def __greedy_non_max_supprs(lpboxes, iou_thresh=0.1):
|
||||
def remove_less_likely(a, b):
|
||||
try:
|
||||
ll = less_likely(a, b)
|
||||
@ -82,16 +82,15 @@ def lpboxes_to_dict(lpboxes):
|
||||
return {"bboxes": boxes, "classes": classes, "probas": probas}
|
||||
|
||||
|
||||
def page_predictions_to_lpboxes(predictions):
|
||||
def greedy_non_max_supprs(predictions):
|
||||
|
||||
boxes, classes, probas = itemgetter("bboxes", "classes", "probas")(predictions)
|
||||
boxes = map(frozendict, boxes)
|
||||
lpboxes = list(starmap(LPBox, zip(classes, probas, boxes)))
|
||||
lpboxes = filter_contained(lpboxes)
|
||||
|
||||
lpboxes = __greedy_non_max_supprs(lpboxes)
|
||||
|
||||
merged_predictions = lpboxes_to_dict(lpboxes)
|
||||
predictions.update(merged_predictions)
|
||||
|
||||
return predictions
|
||||
|
||||
|
||||
def predictions_to_lpboxes(predictions_per_page):
|
||||
return map(page_predictions_to_lpboxes, predictions_per_page)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user