Pull request #2: Non max supprs

Merge in RR/fb_detr_prediction_container from non_max_supprs to master

Squashed commit of the following:

commit 9cc31b70e39412b3613a117228554608d947dbb5
Author: Matthias Bisping <matthias.bisping@iqser.com>
Date:   Fri Feb 4 17:41:00 2022 +0100

    refactoring, renaming

commit ebc37299df598b71f7569d8e8473bdb66bbbbd1a
Author: Matthias Bisping <matthias.bisping@iqser.com>
Date:   Fri Feb 4 17:34:26 2022 +0100

    renaming

commit d694866e1e98e6129f37eaf4c1950b962fed437f
Author: Matthias Bisping <matthias.bisping@iqser.com>
Date:   Fri Feb 4 17:33:07 2022 +0100

    applied black

commit 381fe2dbf5d88f008d87bd807b84174376c5bcfe
Author: Matthias Bisping <matthias.bisping@iqser.com>
Date:   Fri Feb 4 17:32:22 2022 +0100

    duplicate detection removal completed

commit ef2bab300322da3b12326d470f1c41263779e4a0
Author: Julius Unverfehrt <Julius.Unverfehrt@iqser.com>
Date:   Fri Feb 4 09:58:49 2022 +0100

    box merging algo  WIP

commit d770e56a7f31a28dea635816cae3b7b75fed0e24
Author: Julius Unverfehrt <Julius.Unverfehrt@iqser.com>
Date:   Fri Feb 4 09:37:17 2022 +0100

    refactor & box dropping working but algo is faulty & drops too much WIP

commit 289848871caadb4438f889b8a030f30cfb64201a
Author: Matthias Bisping <matthias.bisping@iqser.com>
Date:   Thu Feb 3 23:56:04 2022 +0100

    non max supprs WIP

commit 2f1ec100b2d33409e9178af8d53218b57d9bb0e2
Author: Matthias Bisping <matthias.bisping@iqser.com>
Date:   Thu Feb 3 13:32:22 2022 +0100

    changed Flask to not listen on public IP
This commit is contained in:
Matthias Bisping 2022-02-04 17:45:55 +01:00 committed by Julius Unverfehrt
parent e4dc6631b5
commit 8ebbe0e6a7
6 changed files with 140 additions and 29 deletions

View File

@ -9,6 +9,7 @@ from detr.models import build_model
from detr.test import get_args_parser, infer
from iteration_utilities import starfilter
from fb_detr.utils.non_max_supprs import greedy_non_max_supprs
from fb_detr.utils.config import read_config
@ -62,31 +63,38 @@ class Predictor:
else:
return classes.tolist()
def __format_prediction(self, output: dict):
# Reduce `probas` to its maximum along axis 1 and convert the result to a
# plain Python list via `.tolist()`.
# NOTE(review): presumably `probas` holds one row per detection and one
# column per class, so this yields the top score per detection — confirm
# against the output of `infer`.
@staticmethod
def __format_probas(probas):
return probas.max(axis=1).tolist()
boxes, classes = itemgetter("bboxes", "classes")(output)
def __format_prediction(self, predictions: dict):
boxes, classes, probas = itemgetter("bboxes", "classes", "probas")(predictions)
if len(boxes):
boxes = self.__format_boxes(boxes)
classes = self.__format_classes(classes)
probas = self.__format_probas(probas)
else:
boxes, classes = [], []
boxes, classes, probas = [], [], []
output["bboxes"] = boxes
output["classes"] = classes
predictions["bboxes"] = boxes
predictions["classes"] = classes
predictions["probas"] = probas
return output
return predictions
def __filter_predictions_for_image(self, predictions):
boxes, classes = itemgetter("bboxes", "classes")(predictions)
boxes, classes, probas = itemgetter("bboxes", "classes", "probas")(predictions)
if boxes:
keep = map(lambda c: c != self.rejection_class, classes)
compressed = list(compress(zip(boxes, classes), keep))
boxes, classes = map(list, zip(*compressed)) if compressed else ([], [])
compressed = list(compress(zip(boxes, classes, probas), keep))
boxes, classes, probas = map(list, zip(*compressed)) if compressed else ([], [], [])
predictions["bboxes"] = boxes
predictions["classes"] = classes
predictions["probas"] = probas
return predictions
@ -106,16 +114,22 @@ class Predictor:
def format_predictions(self, outputs: Iterable):
return map(self.__format_prediction, outputs)
def predict(self, images, threshold=None, format_output=False):
# Apply `greedy_non_max_supprs` to every per-image prediction.
# The result is a lazy `map` iterator, not a list: nothing is computed until
# the caller consumes it.
def __merge_boxes(self, predictions):
predictions = map(greedy_non_max_supprs, predictions)
return predictions
def predict(self, images, threshold=None):
if not threshold:
threshold = read_config("threshold")
predictions = infer(images, self.model, read_config("device"), threshold)
predictions = self.format_predictions(predictions)
if self.rejection_class:
predictions = self.filter_predictions(predictions)
if format_output:
predictions = self.format_predictions(predictions)
if self.rejection_class:
predictions = self.filter_predictions(predictions)
predictions = self.__merge_boxes(predictions)
predictions = list(predictions)
return predictions

View File

@ -0,0 +1,96 @@
from collections import namedtuple
from itertools import starmap, combinations
from operator import attrgetter, itemgetter
from frozendict import frozendict
# Rectangle described by two corners: (xmin, ymin) and (xmax, ymax).
Rectangle = namedtuple("Rectangle", "xmin ymin xmax ymax")
def make_box(x1, y1, x2, y2):
    """Build a box dict with the keys "x1", "y1", "x2", "y2" (in that order)."""
    return {"x1": x1, "y1": y1, "x2": x2, "y2": y2}
def compute_intersection(a, b):
    """Return the area shared by two boxes, or 0 if they do not overlap.

    Coordinates are looked up by key, so the result does not depend on the
    order in which the keys were inserted into the box mappings.

    Args:
        a: Mapping with the keys "x1", "y1", "x2", "y2".
        b: Mapping with the same keys.

    Returns:
        Width times height of the overlapping region. Boxes that only touch
        along an edge or corner yield 0, as do disjoint boxes.

    Raises:
        KeyError: If one of the four coordinate keys is missing.
    """
    corners = itemgetter("x1", "y1", "x2", "y2")
    a_xmin, a_ymin, a_xmax, a_ymax = corners(a)
    b_xmin, b_ymin, b_xmax, b_ymax = corners(b)

    dx = min(a_xmax, b_xmax) - max(a_xmin, b_xmin)
    dy = min(a_ymax, b_ymax) - max(a_ymin, b_ymin)

    # A negative extent on either axis means the boxes are disjoint there.
    return dx * dy if (dx >= 0) and (dy >= 0) else 0
def compute_union(a, b):
    """Return the area covered by at least one of the two boxes.

    Args:
        a: Mapping with the keys "x1", "y1", "x2", "y2".
        b: Mapping with the same keys.

    Returns:
        area(a) + area(b) minus the area counted twice (their intersection).
    """

    def area(box):
        # Read coordinates by key so the key insertion order is irrelevant.
        x1, y1, x2, y2 = itemgetter("x1", "y1", "x2", "y2")(box)
        return (x2 - x1) * (y2 - y1)

    return (area(a) + area(b)) - compute_intersection(a, b)
def compute_iou(a, b):
    """Return intersection-over-union of two boxes.

    Args:
        a: Box mapping accepted by `compute_intersection` / `compute_union`.
        b: Box mapping of the same form.

    Returns:
        Intersection area divided by union area. When the union is zero
        (both boxes have no area) the result is 0.0 instead of raising
        ZeroDivisionError.
    """
    union = compute_union(a, b)
    if not union:
        return 0.0
    return compute_intersection(a, b) / union
# One detection: its class label, the probability of that label, and its box.
LPBox = namedtuple("LPBox", "label proba box")
def less_likely(a, b):
    """Return whichever of the two detections has the lower `proba`.

    On a tie the first argument is returned.
    """
    if b.proba < a.proba:
        return b
    return a
def overlap_too_much(a, b, iou_thresh):
    """Tell whether the IoU of the two detections' boxes exceeds `iou_thresh`.

    The comparison is strict: an IoU equal to the threshold is not too much.
    """
    return compute_iou(a.box, b.box) > iou_thresh
def __greedy_non_max_supprs(lpboxes, iou_thresh=0.1):
    """Drop detections that overlap a more probable detection too much.

    Detections are visited from most to least probable. A detection is kept
    only if it does not overlap (per `overlap_too_much`) any detection that
    was already kept. Because only surviving detections can suppress others,
    the outcome is the same on every run and a box is never removed on behalf
    of a box that is itself removed.

    Args:
        lpboxes: Iterable of hashable detections exposing `proba` and `box`.
        iou_thresh: IoU above which two detections count as duplicates.

    Returns:
        Set of the surviving detections.
    """
    # dict.fromkeys removes exact duplicates (as a set would) while keeping
    # the input order, so ties in `proba` are resolved reproducibly.
    candidates = sorted(dict.fromkeys(lpboxes), key=attrgetter("proba"), reverse=True)

    survivors = []
    for candidate in candidates:
        if not any(overlap_too_much(candidate, winner, iou_thresh) for winner in survivors):
            survivors.append(candidate)

    return set(survivors)
def lpboxes_to_dict(lpboxes):
    """Split detections into parallel lists of boxes, labels and probabilities.

    The input is traversed exactly once, so one-shot iterators (generators,
    `map` objects) are handled correctly as well as lists and sets.

    Args:
        lpboxes: Iterable of objects with the attributes `box`, `label`
            and `proba`.

    Returns:
        Dict with the keys "bboxes" (each box copied into a plain dict),
        "classes" and "probas"; entry i of every list belongs to the same
        detection.
    """
    boxes, classes, probas = [], [], []
    for lpbox in lpboxes:
        boxes.append(dict(lpbox.box))
        classes.append(lpbox.label)
        probas.append(lpbox.proba)
    return {"bboxes": boxes, "classes": classes, "probas": probas}
def greedy_non_max_supprs(predictions):
    """Remove duplicate detections from one prediction dict, in place.

    Args:
        predictions: Dict with the parallel lists "bboxes", "classes" and
            "probas". Any other keys are left untouched.

    Returns:
        The same dict object, with the three lists replaced by the surviving
        detections ordered from most to least probable.
    """
    boxes, classes, probas = itemgetter("bboxes", "classes", "probas")(predictions)

    # frozendict makes each box hashable, which the suppression step needs.
    lpboxes = [LPBox(label, proba, frozendict(box)) for label, proba, box in zip(classes, probas, boxes)]

    survivors = __greedy_non_max_supprs(lpboxes)
    # The survivors arrive as a set, whose iteration order is not stable
    # across runs; sort so the output order is reproducible.
    survivors = sorted(survivors, key=attrgetter("proba"), reverse=True)

    predictions.update(lpboxes_to_dict(survivors))
    return predictions

@ -1 +1 @@
Subproject commit 7e3258ccc1fa2be7a9d8ab333873b79de7005809
Subproject commit c17cddd980ae3003a2633a65744d2265228e4c71

View File

@ -12,3 +12,4 @@ requests==2.27.1
iteration-utilities==0.11.0
dvc==2.9.3
dvc[ssh]
frozendict==2.3.0

View File

@ -4,20 +4,23 @@ from operator import itemgetter
import pdf2image
import requests
from PIL import ImageDraw
from PIL import ImageDraw, ImageFont
def draw_coco_box(draw: ImageDraw.Draw, bbox, klass):
def draw_coco_box(draw: ImageDraw.Draw, bbox, klass, proba):
x1, y1, x2, y2 = itemgetter("x1", "y1", "x2", "y2")(bbox)
draw.rectangle(((x1, y1), (x2, y2)), outline="red")
draw.text((x1, y1), text=klass, fill=(0, 0, 0, 100))
fnt = ImageFont.truetype("Pillow/Tests/fonts/FreeMono.ttf", 30)
draw.text((x1, y2), text=f"{klass}: {proba:.2f}", fill=(0, 0, 0, 100), font=fnt)
def draw_coco_boxes(image, bboxes, classes):
def draw_coco_boxes(image, bboxes, classes, probas):
draw = ImageDraw.Draw(image)
for bbox, klass in zip(bboxes, classes):
draw_coco_box(draw, bbox, klass)
for bbox, klass, proba in zip(bboxes, classes, probas):
draw_coco_box(draw, bbox, klass, proba)
return image
@ -26,9 +29,9 @@ def annotate(pdf_path, predictions):
pages = pdf2image.convert_from_path(pdf_path)
for prd in predictions:
page_idx, boxes, classes = itemgetter("page_idx", "bboxes", "classes")(prd)
page_idx, boxes, classes, probas = itemgetter("page_idx", "bboxes", "classes", "probas")(prd)
page = pages[page_idx]
image = draw_coco_boxes(page, boxes, classes)
image = draw_coco_boxes(page, boxes, classes, probas)
image.save(f"/tmp/serv_out/{page_idx}.png")
@ -42,14 +45,11 @@ def parse_args():
def main(args):
response = requests.post("http://0.0.0.0:8080", data=open(args.pdf_path, "rb"))
response = requests.post("http://127.0.0.1:5000", data=open(args.pdf_path, "rb"))
response.raise_for_status()
predictions = response.json()
print(json.dumps(predictions, indent=2))
annotate(args.pdf_path, predictions)

View File

@ -47,7 +47,7 @@ def main(args):
pdf = request.data
pages = pdf2image.convert_from_bytes(pdf)
predictions = predictor.predict(pages, format_output=True)
predictions = predictor.predict(pages)
return jsonify(list(predictions))
@ -58,7 +58,7 @@ def main(args):
predictor = initialize_predictor()
app.run(host="0.0.0.0", port=8080)
app.run(host="127.0.0.1", port=5000)
if __name__ == "__main__":