Pull request #2: Non max supprs
Merge in RR/fb_detr_prediction_container from non_max_supprs to master
Squashed commit of the following:
commit 9cc31b70e39412b3613a117228554608d947dbb5
Author: Matthias Bisping <matthias.bisping@iqser.com>
Date: Fri Feb 4 17:41:00 2022 +0100
refactoring, renaming
commit ebc37299df598b71f7569d8e8473bdb66bbbbd1a
Author: Matthias Bisping <matthias.bisping@iqser.com>
Date: Fri Feb 4 17:34:26 2022 +0100
renaming
commit d694866e1e98e6129f37eaf4c1950b962fed437f
Author: Matthias Bisping <matthias.bisping@iqser.com>
Date: Fri Feb 4 17:33:07 2022 +0100
applied black
commit 381fe2dbf5d88f008d87bd807b84174376c5bcfe
Author: Matthias Bisping <matthias.bisping@iqser.com>
Date: Fri Feb 4 17:32:22 2022 +0100
duplicate detection removal completed
commit ef2bab300322da3b12326d470f1c41263779e4a0
Author: Julius Unverfehrt <Julius.Unverfehrt@iqser.com>
Date: Fri Feb 4 09:58:49 2022 +0100
box merging algo WIP
commit d770e56a7f31a28dea635816cae3b7b75fed0e24
Author: Julius Unverfehrt <Julius.Unverfehrt@iqser.com>
Date: Fri Feb 4 09:37:17 2022 +0100
refactor & box dropping working but algo is faulty & drops too much WIP
commit 289848871caadb4438f889b8a030f30cfb64201a
Author: Matthias Bisping <matthias.bisping@iqser.com>
Date: Thu Feb 3 23:56:04 2022 +0100
non max supprs WIP
commit 2f1ec100b2d33409e9178af8d53218b57d9bb0e2
Author: Matthias Bisping <matthias.bisping@iqser.com>
Date: Thu Feb 3 13:32:22 2022 +0100
changed Flask to not listen on public IP
This commit is contained in:
parent
e4dc6631b5
commit
8ebbe0e6a7
@ -9,6 +9,7 @@ from detr.models import build_model
|
||||
from detr.test import get_args_parser, infer
|
||||
from iteration_utilities import starfilter
|
||||
|
||||
from fb_detr.utils.non_max_supprs import greedy_non_max_supprs
|
||||
from fb_detr.utils.config import read_config
|
||||
|
||||
|
||||
@ -62,31 +63,38 @@ class Predictor:
|
||||
else:
|
||||
return classes.tolist()
|
||||
|
||||
def __format_prediction(self, output: dict):
|
||||
@staticmethod
|
||||
def __format_probas(probas):
|
||||
return probas.max(axis=1).tolist()
|
||||
|
||||
boxes, classes = itemgetter("bboxes", "classes")(output)
|
||||
def __format_prediction(self, predictions: dict):
|
||||
|
||||
boxes, classes, probas = itemgetter("bboxes", "classes", "probas")(predictions)
|
||||
|
||||
if len(boxes):
|
||||
boxes = self.__format_boxes(boxes)
|
||||
classes = self.__format_classes(classes)
|
||||
probas = self.__format_probas(probas)
|
||||
else:
|
||||
boxes, classes = [], []
|
||||
boxes, classes, probas = [], [], []
|
||||
|
||||
output["bboxes"] = boxes
|
||||
output["classes"] = classes
|
||||
predictions["bboxes"] = boxes
|
||||
predictions["classes"] = classes
|
||||
predictions["probas"] = probas
|
||||
|
||||
return output
|
||||
return predictions
|
||||
|
||||
def __filter_predictions_for_image(self, predictions):
|
||||
|
||||
boxes, classes = itemgetter("bboxes", "classes")(predictions)
|
||||
boxes, classes, probas = itemgetter("bboxes", "classes", "probas")(predictions)
|
||||
|
||||
if boxes:
|
||||
keep = map(lambda c: c != self.rejection_class, classes)
|
||||
compressed = list(compress(zip(boxes, classes), keep))
|
||||
boxes, classes = map(list, zip(*compressed)) if compressed else ([], [])
|
||||
compressed = list(compress(zip(boxes, classes, probas), keep))
|
||||
boxes, classes, probas = map(list, zip(*compressed)) if compressed else ([], [], [])
|
||||
predictions["bboxes"] = boxes
|
||||
predictions["classes"] = classes
|
||||
predictions["probas"] = probas
|
||||
|
||||
return predictions
|
||||
|
||||
@ -106,16 +114,22 @@ class Predictor:
|
||||
def format_predictions(self, outputs: Iterable):
|
||||
return map(self.__format_prediction, outputs)
|
||||
|
||||
def predict(self, images, threshold=None, format_output=False):
|
||||
def __merge_boxes(self, predictions):
|
||||
predictions = map(greedy_non_max_supprs, predictions)
|
||||
return predictions
|
||||
|
||||
def predict(self, images, threshold=None):
|
||||
|
||||
if not threshold:
|
||||
threshold = read_config("threshold")
|
||||
|
||||
predictions = infer(images, self.model, read_config("device"), threshold)
|
||||
predictions = self.format_predictions(predictions)
|
||||
if self.rejection_class:
|
||||
predictions = self.filter_predictions(predictions)
|
||||
|
||||
if format_output:
|
||||
predictions = self.format_predictions(predictions)
|
||||
if self.rejection_class:
|
||||
predictions = self.filter_predictions(predictions)
|
||||
predictions = self.__merge_boxes(predictions)
|
||||
|
||||
predictions = list(predictions)
|
||||
|
||||
return predictions
|
||||
|
||||
96
fb_detr/utils/non_max_supprs.py
Normal file
96
fb_detr/utils/non_max_supprs.py
Normal file
@ -0,0 +1,96 @@
|
||||
from collections import namedtuple
|
||||
from itertools import starmap, combinations
|
||||
from operator import attrgetter, itemgetter
|
||||
|
||||
from frozendict import frozendict
|
||||
|
||||
Rectangle = namedtuple("Rectangle", "xmin ymin xmax ymax")
|
||||
|
||||
|
||||
def make_box(x1, y1, x2, y2):
    """Bundle four corner coordinates into a box dict.

    Returns:
        A dict with the keys "x1", "y1", "x2", "y2" (in that order) mapped to
        the corresponding arguments.
    """
    return {"x1": x1, "y1": y1, "x2": x2, "y2": y2}
|
||||
|
||||
|
||||
def compute_intersection(a, b):
    """Return the overlap area of two axis-aligned boxes.

    Coordinates are looked up by key rather than by position, so the result
    does not depend on the insertion order of the mappings and extra keys are
    ignored.

    Args:
        a: mapping with the keys "x1", "y1", "x2", "y2".
        b: mapping with the keys "x1", "y1", "x2", "y2".

    Returns:
        The area shared by both boxes, or 0 when they do not overlap (boxes
        that merely touch along an edge also yield 0).

    Raises:
        KeyError: if one of the four coordinate keys is missing.
    """
    corners = itemgetter("x1", "y1", "x2", "y2")
    a_xmin, a_ymin, a_xmax, a_ymax = corners(a)
    b_xmin, b_ymin, b_xmax, b_ymax = corners(b)

    dx = min(a_xmax, b_xmax) - max(a_xmin, b_xmin)
    dy = min(a_ymax, b_ymax) - max(a_ymin, b_ymin)

    # Both extents must be checked: two negative extents would otherwise
    # multiply to a bogus positive area.
    return dx * dy if dx >= 0 and dy >= 0 else 0
|
||||
|
||||
|
||||
def compute_union(a, b):
    """Return the area covered by at least one of two axis-aligned boxes.

    Computed as area(a) + area(b) minus the overlap reported by
    compute_intersection.

    Args:
        a: mapping with the keys "x1", "y1", "x2", "y2".
        b: mapping with the keys "x1", "y1", "x2", "y2".

    Raises:
        KeyError: if one of the four coordinate keys is missing.
    """

    def area(box):
        # Read coordinates by key so the result does not depend on the
        # insertion order of the mapping.
        xmin, ymin, xmax, ymax = itemgetter("x1", "y1", "x2", "y2")(box)
        return (xmax - xmin) * (ymax - ymin)

    return (area(a) + area(b)) - compute_intersection(a, b)
|
||||
|
||||
|
||||
def compute_iou(a, b):
    """Return the intersection-over-union ratio of two boxes.

    Args:
        a: box mapping accepted by compute_intersection / compute_union.
        b: box mapping accepted by compute_intersection / compute_union.

    Returns:
        intersection area divided by union area, or 0.0 when the union is
        empty (both boxes have zero area), in which case the ratio is
        undefined and the boxes are treated as non-overlapping.
    """
    union = compute_union(a, b)

    # Guard against ZeroDivisionError for two degenerate (zero-area) boxes.
    if union == 0:
        return 0.0

    return compute_intersection(a, b) / union
|
||||
|
||||
|
||||
LPBox = namedtuple("LPBox", "label proba box")
|
||||
|
||||
|
||||
def less_likely(a, b):
    """Return whichever of the two items has the smaller `proba` attribute.

    On a tie the first argument is returned.
    """
    return b if b.proba < a.proba else a
|
||||
|
||||
|
||||
def overlap_too_much(a, b, iou_thresh):
    """Tell whether the IoU of the two items' `box` attributes exceeds `iou_thresh`.

    The comparison is strict: an IoU exactly equal to the threshold does not
    count as too much overlap.
    """
    return compute_iou(a.box, b.box) > iou_thresh
|
||||
|
||||
|
||||
def __greedy_non_max_supprs(lpboxes, iou_thresh=0.1):
    """Greedy non-maximum suppression over labelled, scored boxes.

    Candidates are visited from most to least probable. A candidate survives
    only if its overlap with every box kept so far does not exceed
    `iou_thresh`. Because the visiting order is fixed by the scores, the
    result does not depend on hash or set iteration order, and a box is never
    suppressed by a neighbour that is itself suppressed.

    Args:
        lpboxes: iterable of hashable items exposing `proba` and `box`
            attributes (e.g. LPBox tuples whose box is a frozendict).
        iou_thresh: maximum IoU two surviving boxes may share.

    Returns:
        A list of the surviving items, ordered by descending `proba`.
    """
    # dict.fromkeys drops exact duplicates while preserving input order; the
    # stable sort then resolves score ties in favour of the earlier input.
    candidates = sorted(
        dict.fromkeys(lpboxes), key=attrgetter("proba"), reverse=True
    )

    kept = []
    for candidate in candidates:
        suppressed = any(
            overlap_too_much(candidate, winner, iou_thresh) for winner in kept
        )
        if not suppressed:
            kept.append(candidate)

    return kept
|
||||
|
||||
|
||||
def lpboxes_to_dict(lpboxes):
    """Split labelled boxes into parallel lists keyed like a prediction dict.

    Args:
        lpboxes: iterable of items exposing `box`, `label` and `proba`
            attributes. It is materialised once, so one-shot iterators and
            generators are handled correctly.

    Returns:
        {"bboxes": [...], "classes": [...], "probas": [...]} where each box
        has been copied into a plain dict and the three lists share the
        iteration order of `lpboxes`.
    """
    # Materialise first: the input is walked three times below, which would
    # leave "classes" and "probas" empty for an iterator.
    lpboxes = list(lpboxes)

    return {
        "bboxes": [dict(lpbox.box) for lpbox in lpboxes],
        "classes": [lpbox.label for lpbox in lpboxes],
        "probas": [lpbox.proba for lpbox in lpboxes],
    }
|
||||
|
||||
|
||||
def greedy_non_max_supprs(predictions):
    """Apply greedy non-maximum suppression to one prediction dict.

    Reads the parallel "bboxes", "classes" and "probas" entries, suppresses
    overlapping boxes and writes the surviving entries back under the same
    keys.

    Args:
        predictions: dict holding "bboxes", "classes" and "probas"; it is
            updated in place, any other keys are left untouched.

    Returns:
        The same `predictions` object that was passed in.
    """
    # Boxes are frozen so that each LPBox is hashable.
    labelled = [
        LPBox(label, proba, frozendict(box))
        for box, label, proba in zip(
            predictions["bboxes"], predictions["classes"], predictions["probas"]
        )
    ]

    survivors = __greedy_non_max_supprs(labelled)
    predictions.update(lpboxes_to_dict(survivors))

    return predictions
|
||||
@ -1 +1 @@
|
||||
Subproject commit 7e3258ccc1fa2be7a9d8ab333873b79de7005809
|
||||
Subproject commit c17cddd980ae3003a2633a65744d2265228e4c71
|
||||
@ -12,3 +12,4 @@ requests==2.27.1
|
||||
iteration-utilities==0.11.0
|
||||
dvc==2.9.3
|
||||
dvc[ssh]
|
||||
frozendict==2.3.0
|
||||
|
||||
@ -4,20 +4,23 @@ from operator import itemgetter
|
||||
|
||||
import pdf2image
|
||||
import requests
|
||||
from PIL import ImageDraw
|
||||
from PIL import ImageDraw, ImageFont
|
||||
|
||||
|
||||
def draw_coco_box(draw: ImageDraw.Draw, bbox, klass):
|
||||
def draw_coco_box(draw: ImageDraw.Draw, bbox, klass, proba):
|
||||
x1, y1, x2, y2 = itemgetter("x1", "y1", "x2", "y2")(bbox)
|
||||
draw.rectangle(((x1, y1), (x2, y2)), outline="red")
|
||||
draw.text((x1, y1), text=klass, fill=(0, 0, 0, 100))
|
||||
|
||||
fnt = ImageFont.truetype("Pillow/Tests/fonts/FreeMono.ttf", 30)
|
||||
|
||||
draw.text((x1, y2), text=f"{klass}: {proba:.2f}", fill=(0, 0, 0, 100), font=fnt)
|
||||
|
||||
|
||||
def draw_coco_boxes(image, bboxes, classes):
|
||||
def draw_coco_boxes(image, bboxes, classes, probas):
|
||||
|
||||
draw = ImageDraw.Draw(image)
|
||||
for bbox, klass in zip(bboxes, classes):
|
||||
draw_coco_box(draw, bbox, klass)
|
||||
for bbox, klass, proba in zip(bboxes, classes, probas):
|
||||
draw_coco_box(draw, bbox, klass, proba)
|
||||
|
||||
return image
|
||||
|
||||
@ -26,9 +29,9 @@ def annotate(pdf_path, predictions):
|
||||
pages = pdf2image.convert_from_path(pdf_path)
|
||||
|
||||
for prd in predictions:
|
||||
page_idx, boxes, classes = itemgetter("page_idx", "bboxes", "classes")(prd)
|
||||
page_idx, boxes, classes, probas = itemgetter("page_idx", "bboxes", "classes", "probas")(prd)
|
||||
page = pages[page_idx]
|
||||
image = draw_coco_boxes(page, boxes, classes)
|
||||
image = draw_coco_boxes(page, boxes, classes, probas)
|
||||
image.save(f"/tmp/serv_out/{page_idx}.png")
|
||||
|
||||
|
||||
@ -42,14 +45,11 @@ def parse_args():
|
||||
|
||||
def main(args):
|
||||
|
||||
response = requests.post("http://0.0.0.0:8080", data=open(args.pdf_path, "rb"))
|
||||
|
||||
response = requests.post("http://127.0.0.1:5000", data=open(args.pdf_path, "rb"))
|
||||
response.raise_for_status()
|
||||
|
||||
predictions = response.json()
|
||||
|
||||
print(json.dumps(predictions, indent=2))
|
||||
|
||||
annotate(args.pdf_path, predictions)
|
||||
|
||||
|
||||
|
||||
@ -47,7 +47,7 @@ def main(args):
|
||||
pdf = request.data
|
||||
|
||||
pages = pdf2image.convert_from_bytes(pdf)
|
||||
predictions = predictor.predict(pages, format_output=True)
|
||||
predictions = predictor.predict(pages)
|
||||
|
||||
return jsonify(list(predictions))
|
||||
|
||||
@ -58,7 +58,7 @@ def main(args):
|
||||
|
||||
predictor = initialize_predictor()
|
||||
|
||||
app.run(host="0.0.0.0", port=8080)
|
||||
app.run(host="127.0.0.1", port=5000)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user