Matthias Bisping 8ebbe0e6a7 Pull request #2: Non max supprs
Merge in RR/fb_detr_prediction_container from non_max_supprs to master

Squashed commit of the following:

commit 9cc31b70e39412b3613a117228554608d947dbb5
Author: Matthias Bisping <matthias.bisping@iqser.com>
Date:   Fri Feb 4 17:41:00 2022 +0100

    refactoring, renaming

commit ebc37299df598b71f7569d8e8473bdb66bbbbd1a
Author: Matthias Bisping <matthias.bisping@iqser.com>
Date:   Fri Feb 4 17:34:26 2022 +0100

    renaming

commit d694866e1e98e6129f37eaf4c1950b962fed437f
Author: Matthias Bisping <matthias.bisping@iqser.com>
Date:   Fri Feb 4 17:33:07 2022 +0100

    applied black

commit 381fe2dbf5d88f008d87bd807b84174376c5bcfe
Author: Matthias Bisping <matthias.bisping@iqser.com>
Date:   Fri Feb 4 17:32:22 2022 +0100

    duplicate detection removal completed

commit ef2bab300322da3b12326d470f1c41263779e4a0
Author: Julius Unverfehrt <Julius.Unverfehrt@iqser.com>
Date:   Fri Feb 4 09:58:49 2022 +0100

    box merging algo  WIP

commit d770e56a7f31a28dea635816cae3b7b75fed0e24
Author: Julius Unverfehrt <Julius.Unverfehrt@iqser.com>
Date:   Fri Feb 4 09:37:17 2022 +0100

    refactor & box dropping working but algo is faulty & drops too much WIP

commit 289848871caadb4438f889b8a030f30cfb64201a
Author: Matthias Bisping <matthias.bisping@iqser.com>
Date:   Thu Feb 3 23:56:04 2022 +0100

    non max supprs WIP

commit 2f1ec100b2d33409e9178af8d53218b57d9bb0e2
Author: Matthias Bisping <matthias.bisping@iqser.com>
Date:   Thu Feb 3 13:32:22 2022 +0100

    changed Flask to not listen on public IP
2022-02-04 17:45:55 +01:00

136 lines
4.1 KiB
Python

import argparse
from itertools import compress, starmap
from operator import itemgetter
from pathlib import Path
from typing import Iterable
import torch
from detr.models import build_model
from detr.test import get_args_parser, infer
from iteration_utilities import starfilter
from fb_detr.utils.non_max_supprs import greedy_non_max_supprs
from fb_detr.utils.config import read_config
def load_model(checkpoint_path):
    """Build a DETR model and restore its weights from *checkpoint_path*.

    :param checkpoint_path: torch checkpoint file holding a ``"model"`` state dict
    :return: the restored model, moved to the configured device
    """
    # The DETR CLI parser supplies the architecture hyper-parameters.
    arg_parser = argparse.ArgumentParser(parents=[get_args_parser()])
    cli_args = arg_parser.parse_args()
    if cli_args.output_dir:
        Path(cli_args.output_dir).mkdir(parents=True, exist_ok=True)

    target_device = torch.device(read_config("device"))
    model, _, _ = build_model(cli_args)
    # Load onto CPU first, then transfer to the target device.
    state = torch.load(checkpoint_path, map_location="cpu")
    model.load_state_dict(state["model"])
    model.to(target_device)
    return model
class Predictor:
    """Runs DETR inference on images and post-processes the raw output into
    JSON-serializable dicts with box coordinates, class labels and scores."""

    def __init__(self, checkpoint_path, classes=None, rejection_class=None):
        """
        :param checkpoint_path: checkpoint file used to restore the model weights
        :param classes: optional sequence mapping class indices to labels
        :param rejection_class: optional label whose detections are discarded
        """
        self.model = load_model(checkpoint_path)
        self.classes = classes
        self.rejection_class = rejection_class

    @staticmethod
    def __format_boxes(boxes):
        """Convert an (N, 4) box array into a list of corner-coordinate dicts."""
        keys = "x1", "y1", "x2", "y2"
        # Extract only the first four columns, row by row.
        cols = (boxes[:, i].tolist() for i in range(4))
        return [dict(zip(keys, vals)) for vals in zip(*cols)]

    @staticmethod
    def __normalize_to_list(maybe_multiple):
        """Wrap a scalar into a 1-tuple; pass tuples through unchanged.

        itemgetter with a single key returns a bare value instead of a tuple,
        so this normalizes both shapes to a tuple.
        """
        return maybe_multiple if isinstance(maybe_multiple, tuple) else tuple([maybe_multiple])

    def __format_classes(self, classes):
        """Map numeric class indices to labels via ``self.classes`` if provided."""
        if self.classes:
            return self.__normalize_to_list(itemgetter(*classes.tolist())(self.classes))
        else:
            return classes.tolist()

    @staticmethod
    def __format_probas(probas):
        # NOTE(review): assumes *probas* is a numpy-style array where
        # ``max(axis=1)`` returns an array — a torch.Tensor would return a
        # (values, indices) namedtuple without ``.tolist`` — confirm upstream.
        return probas.max(axis=1).tolist()

    def __format_prediction(self, predictions: dict):
        """Format one image's raw prediction in place and return it."""
        boxes, classes, probas = itemgetter("bboxes", "classes", "probas")(predictions)
        if len(boxes):
            boxes = self.__format_boxes(boxes)
            classes = self.__format_classes(classes)
            probas = self.__format_probas(probas)
        else:
            # No detections: normalize to empty lists for JSON output.
            boxes, classes, probas = [], [], []
        predictions["bboxes"] = boxes
        predictions["classes"] = classes
        predictions["probas"] = probas
        return predictions

    def __filter_predictions_for_image(self, predictions):
        """Drop every detection labeled with the rejection class (in place)."""
        boxes, classes, probas = itemgetter("bboxes", "classes", "probas")(predictions)
        if boxes:
            keep = map(lambda c: c != self.rejection_class, classes)
            compressed = list(compress(zip(boxes, classes, probas), keep))
            boxes, classes, probas = map(list, zip(*compressed)) if compressed else ([], [], [])
        predictions["bboxes"] = boxes
        predictions["classes"] = classes
        predictions["probas"] = probas
        return predictions

    def filter_predictions(self, predictions):
        """Remove rejected detections, drop empty pages, and tag each
        remaining prediction with its page index."""

        def detections_present(_, prediction):
            return bool(prediction["classes"])

        def build_return_dict(page_idx, predictions):
            return {"page_idx": page_idx, **predictions}

        filtered_rejections = map(self.__filter_predictions_for_image, predictions)
        filtered_no_detections = starfilter(detections_present, enumerate(filtered_rejections))
        filtered_no_detections = starmap(build_return_dict, filtered_no_detections)
        return filtered_no_detections

    def format_predictions(self, outputs: Iterable):
        """Lazily format every raw model output."""
        return map(self.__format_prediction, outputs)

    def __merge_boxes(self, predictions):
        """Apply greedy non-maximum suppression to each image's boxes."""
        predictions = map(greedy_non_max_supprs, predictions)
        return predictions

    def predict(self, images, threshold=None):
        """Run inference on *images* and return post-processed predictions.

        :param threshold: score threshold; falls back to the configured value
            when omitted
        :return: list of prediction dicts, one per page with detections
        """
        # BUGFIX: was ``if not threshold`` — an explicit threshold of 0.0 was
        # silently replaced by the configured default.
        if threshold is None:
            threshold = read_config("threshold")
        predictions = infer(images, self.model, read_config("device"), threshold)
        predictions = self.format_predictions(predictions)
        # BUGFIX: was a truthiness check — a falsy rejection class (e.g. the
        # class index 0) silently disabled filtering.
        if self.rejection_class is not None:
            predictions = self.filter_predictions(predictions)
        predictions = self.__merge_boxes(predictions)
        predictions = list(predictions)
        return predictions