Merge in RR/fb_detr_prediction_container from setup to master
Squashed commit of the following:
commit 7fae4878d4250676367b7201fa163a4b67f79f84
Author: Matthias Bisping <matthias.bisping@iqser.com>
Date: Thu Feb 3 11:22:12 2022 +0100
readded annotation to client
commit ff788030f6b3b342919a7fd31dfa66940033d7e1
Author: Matthias Bisping <matthias.bisping@iqser.com>
Date: Thu Feb 3 11:15:16 2022 +0100
applied black
commit 3521444f678950a2772b725c6964751e0e655736
Merge: 4080aff 51d6597
Author: Matthias Bisping <matthias.bisping@iqser.com>
Date: Thu Feb 3 10:39:11 2022 +0100
Merge branch 'setup' of ssh://git.iqser.com:2222/rr/fb_detr_prediction_container into setup
commit 4080affd21a02ad32c61fbd2027511f51a202d63
Author: Matthias Bisping <matthias.bisping@iqser.com>
Date: Thu Feb 3 10:39:02 2022 +0100
added poppler-utils download to Dockerfile, since pdf2image is only a wrapper for it
commit 51d6597b056ae9ac693280f65a3f37d46b1276cf
Author: Julius Unverfehrt <Julius.Unverfehrt@iqser.com>
Date: Thu Feb 3 09:43:35 2022 +0100
Structure change for local backbone lookup (working now)
commit ac314d5148d6e026c67f00df45a8bbc70c15b52d
Author: Julius Unverfehrt <Julius.Unverfehrt@iqser.com>
Date: Thu Feb 3 09:35:41 2022 +0100
env bug fixed
commit 1c3221fe4956911b29fd8fede8d07dcdefad06d8
Author: Julius Unverfehrt <Julius.Unverfehrt@iqser.com>
Date: Thu Feb 3 09:23:55 2022 +0100
ENV correctly set now
commit 58069440583f1f78cfb2fb796fa4dc4a63e2916a
Author: Julius Unverfehrt <Julius.Unverfehrt@iqser.com>
Date: Thu Feb 3 08:41:29 2022 +0100
ENV for local torch model lookup set
commit f0501cf0bf904793e8e04afbd3d80ee84af9d981
Author: Matthias Bisping <matthias.bisping@iqser.com>
Date: Wed Feb 2 18:28:44 2022 +0100
changed host and port for flask
commit 986fda22f6656b10930628d0d284995b33ea2df5
Author: Matthias Bisping <matthias.bisping@iqser.com>
Date: Wed Feb 2 17:33:07 2022 +0100
added debug webserver method
commit 64b857ce53757ec2b7e7c327962fa65b551603a0
Author: Matthias Bisping <matthias.bisping@iqser.com>
Date: Wed Feb 2 16:59:11 2022 +0100
moved utils into module; fixed open-cv (maybe)
commit c62ada183135e12b41a29c6822472e33698f947f
Author: Matthias Bisping <matthias.bisping@iqser.com>
Date: Wed Feb 2 15:55:10 2022 +0100
made bash scripts executable
commit 982bdd7503c14fcf1776ae10c38589475199545e
Author: Matthias Bisping <matthias.bisping@iqser.com>
Date: Wed Feb 2 15:35:16 2022 +0100
service building logic added (WIP)
commit 46e5e3b8e67e54ecedaeee4765a3437f08fa4b17
Author: Matthias Bisping <matthias.bisping@iqser.com>
Date: Wed Feb 2 14:37:28 2022 +0100
applied black
commit ad93130e66d2e87bc86b2bf1de6234f3c037df48
Author: Matthias Bisping <matthias.bisping@iqser.com>
Date: Wed Feb 2 14:36:09 2022 +0100
fixed formatting (w, h -> x2, y2); added drawing logic to caller mock
commit df76f033599e66aaa52143f5e2b156530f643df9
Author: Matthias Bisping <matthias.bisping@iqser.com>
Date: Wed Feb 2 13:54:34 2022 +0100
page indices in predictions
commit 5e87c57dff752419486d1a44de9a734e3f840816
Author: Matthias Bisping <matthias.bisping@iqser.com>
Date: Wed Feb 2 13:17:34 2022 +0100
service main loop WIP (working in basic version)
commit ba5ec3d57621d090201413309126955940602be9
Author: Matthias Bisping <matthias.bisping@iqser.com>
Date: Wed Feb 2 13:03:52 2022 +0100
service main loop WIP
commit 77266f6982ec826eadcdd8a18c5ccf0fc380611b
Author: Matthias Bisping <matthias.bisping@iqser.com>
Date: Wed Feb 2 11:24:27 2022 +0100
fixed bug for self.classes == None
commit 858ef7589d6914ad503660a3ddc5e75bf72a6bb7
Author: Matthias Bisping <matthias.bisping@iqser.com>
Date: Wed Feb 2 11:09:11 2022 +0100
removed 'postprocessors' argument and attribute
... and 32 more commits
122 lines
3.6 KiB
Python
122 lines
3.6 KiB
Python
import argparse
|
|
from itertools import compress, starmap
|
|
from operator import itemgetter
|
|
from pathlib import Path
|
|
from typing import Iterable
|
|
|
|
import torch
|
|
from detr.models import build_model
|
|
from detr.test import get_args_parser, infer
|
|
from iteration_utilities import starfilter
|
|
|
|
from fb_detr.utils.config import read_config
|
|
|
|
|
|
def load_model(checkpoint_path):
    """Build the DETR model and load its weights from ``checkpoint_path``.

    The model is built from command-line arguments: an ``argparse`` parser
    derived from ``detr.test.get_args_parser`` parses the process arguments,
    and the result is handed to ``build_model``. The target device name comes
    from ``read_config("device")``.

    Args:
        checkpoint_path: Path of a checkpoint file whose ``"model"`` entry holds
            the model state dict.

    Returns:
        The model with the checkpoint weights loaded, moved to the configured
        device.
    """
    cli_parser = argparse.ArgumentParser(parents=[get_args_parser()])
    # NOTE(review): parse_args() without an explicit list reads sys.argv of the
    # hosting process; argparse exits on flags it does not recognise.
    cli_args = cli_parser.parse_args()

    if cli_args.output_dir:
        Path(cli_args.output_dir).mkdir(parents=True, exist_ok=True)

    target_device = torch.device(read_config("device"))

    # build_model must return exactly three values; only the first is kept.
    model, _, _ = build_model(cli_args)

    # Deserialize onto the CPU first, then move the model afterwards.
    # torch.load unpickles its input, so only point this at trusted files.
    state_dict = torch.load(checkpoint_path, map_location="cpu")["model"]
    model.load_state_dict(state_dict)

    model.to(target_device)
    # NOTE(review): train/eval mode is not changed here; presumably `infer`
    # switches to eval mode — confirm.
    return model
|
|
|
|
|
|
class Predictor:
    """Run DETR inference on images and optionally post-process the outputs.

    Args:
        checkpoint_path: Checkpoint file passed to ``load_model``.
        classes: Optional label lookup indexed by predicted class index
            (e.g. a list of names). When truthy, formatted predictions carry
            ``classes[i]`` instead of the raw index ``i``.
        rejection_class: Optional class value whose detections are discarded
            when formatted output is requested. It is compared against labels
            when ``classes`` is set, and against raw indices otherwise.
    """

    def __init__(self, checkpoint_path, classes=None, rejection_class=None):
        self.model = load_model(checkpoint_path)
        self.classes = classes
        self.rejection_class = rejection_class

    @staticmethod
    def __format_boxes(boxes):
        """Convert a 2-D box array into ``[{"x1", "y1", "x2", "y2"}, ...]`` dicts.

        ``boxes`` must support ``.tolist()``; the first four values of each row
        become ``x1``, ``y1``, ``x2`` and ``y2``.
        """
        keys = ("x1", "y1", "x2", "y2")
        return [dict(zip(keys, row)) for row in boxes.tolist()]

    def __format_classes(self, classes):
        """Return predicted classes as a list, mapped to labels if configured.

        ``classes`` must support ``.tolist()``. The result is always a list,
        whether or not ``self.classes`` is set.
        """
        indices = classes.tolist()
        if self.classes:
            return [self.classes[index] for index in indices]
        return indices

    def __format_prediction(self, output: dict):
        """Replace ``output["bboxes"]`` / ``output["classes"]`` with plain lists.

        Mutates and returns ``output``. When there are no boxes, both entries
        become empty lists.
        """
        boxes, classes = itemgetter("bboxes", "classes")(output)

        if len(boxes):
            output["bboxes"] = self.__format_boxes(boxes)
            output["classes"] = self.__format_classes(classes)
        else:
            output["bboxes"], output["classes"] = [], []

        return output

    def __filter_predictions_for_image(self, predictions):
        """Drop detections whose class equals ``self.rejection_class``.

        Expects the list-valued ``bboxes``/``classes`` produced by formatting.
        Mutates and returns ``predictions``; boxes and classes stay aligned.
        """
        boxes, classes = itemgetter("bboxes", "classes")(predictions)

        if boxes:
            kept = [
                (box, cls)
                for box, cls in zip(boxes, classes)
                if cls != self.rejection_class
            ]
            predictions["bboxes"] = [box for box, _ in kept]
            predictions["classes"] = [cls for _, cls in kept]

        return predictions

    def filter_predictions(self, predictions):
        """Remove rejected detections, tag pages with their index, drop empty pages.

        Args:
            predictions: Iterable of formatted per-image prediction dicts.

        Returns:
            A lazy iterator of ``{"page_idx": i, **prediction}`` dicts, one per
            page that still has detections. ``i`` is the page's position in
            ``predictions``, counted before empty pages are dropped.
        """
        filtered = map(self.__filter_predictions_for_image, predictions)
        return (
            {"page_idx": page_idx, **prediction}
            for page_idx, prediction in enumerate(filtered)
            if prediction["classes"]
        )

    def format_predictions(self, outputs: Iterable):
        """Lazily format each raw output dict (see ``__format_prediction``)."""
        return map(self.__format_prediction, outputs)

    def predict(self, images, threshold=None, format_output=False):
        """Run inference on ``images``.

        Args:
            images: Passed unchanged to ``infer``.
            threshold: Score threshold forwarded to ``infer``. Only ``None``
                falls back to ``read_config("threshold")``; falsy values such
                as ``0`` are forwarded as given.
            format_output: When true, convert outputs via ``format_predictions``
                and, if a ``rejection_class`` is configured (including a falsy
                one such as class index ``0``), apply ``filter_predictions``.

        Returns:
            The raw ``infer`` result when ``format_output`` is false; otherwise
            a lazy, single-pass iterator of formatted (and possibly filtered)
            prediction dicts.
        """
        if threshold is None:
            threshold = read_config("threshold")

        predictions = infer(images, self.model, read_config("device"), threshold)

        if format_output:
            predictions = self.format_predictions(predictions)
            # Filtering relies on the list-valued output of formatting, so it
            # only runs on formatted predictions.
            if self.rejection_class is not None:
                predictions = self.filter_predictions(predictions)

        return predictions
|