Merge in RR/fb_detr_prediction_container from docker-image-tuning to master
Squashed commit of the following:
commit 9b30e6317aaf892fcb6f87275d03e2efb76954bf
Author: Matthias Bisping <matthias.bisping@iqser.com>
Date: Mon Feb 21 15:17:01 2022 +0100
applied black
commit 84a57ac29723910dbc2c4d8ccce58c9d3131a305
Author: Matthias Bisping <matthias.bisping@iqser.com>
Date: Mon Feb 21 15:15:15 2022 +0100
refactoring of tqdm
commit b26c52765c58125826099072d510a39baabce73e
Author: Julius Unverfehrt <Julius.Unverfehrt@iqser.com>
Date: Mon Feb 21 14:52:11 2022 +0100
correcting versioning of docker-compose
commit 23752eec0d95cc543f15a86c78bd8531ebfdde7d
Author: Matthias Bisping <matthias.bisping@iqser.com>
Date: Mon Feb 21 14:49:26 2022 +0100
put tqdm progress in different place
commit e2e109ea7125c90f5b15ec374f3cbfef41e2ee9e
Author: Matthias Bisping <matthias.bisping@iqser.com>
Date: Mon Feb 21 11:51:09 2022 +0100
fixed batching index bug
commit 6ca508ac55dd02ded356617653f580099e1cf186
Author: Matthias Bisping <matthias.bisping@iqser.com>
Date: Mon Feb 21 11:37:34 2022 +0100
batching WIP
commit 0ceb7c1415b10230397f4860ac4e314d44bfbfd1
Author: Matthias Bisping <matthias.bisping@iqser.com>
Date: Fri Feb 18 19:21:02 2022 +0100
debug mode for webserver renamed
commit 617f07a0296ad3efc85b6ee52d1641cdfa22d3d3
Author: Matthias Bisping <matthias.bisping@iqser.com>
Date: Fri Feb 18 18:25:39 2022 +0100
refactoring, better logging, added compose file for local testing
commit a24f799614e22481dd20b578c354e33474bec5c0
Author: Matthias Bisping <matthias.bisping@iqser.com>
Date: Fri Feb 18 17:31:14 2022 +0100
updated submodule
commit 67b64606e081373e5c30ccf5bfafcb91dcc9a74e
Author: Matthias Bisping <matthias.bisping@iqser.com>
Date: Fri Feb 18 17:29:20 2022 +0100
cleanup: better config; refactoring; renaming
commit c3a1ab560879d6a1e6ce003c74a07d62175316f7
Author: Matthias Bisping <matthias.bisping@iqser.com>
Date: Fri Feb 18 15:55:35 2022 +0100
tweaked dockerfiles
commit 43f7a32265243bc0f110bd307325b5404e8726a8
Author: Matthias Bisping <matthias.bisping@iqser.com>
Date: Fri Feb 18 15:02:49 2022 +0100
added dockerignore
157 lines
4.9 KiB
Python
import argparse
|
|
import logging
|
|
from itertools import compress, starmap, chain
|
|
from operator import itemgetter
|
|
from pathlib import Path
|
|
from typing import Iterable
|
|
|
|
import torch
|
|
from iteration_utilities import starfilter
|
|
from tqdm import tqdm
|
|
|
|
from detr.models import build_model
|
|
from detr.prediction import get_args_parser, infer
|
|
from fb_detr.config import CONFIG
|
|
from fb_detr.utils.non_max_supprs import greedy_non_max_supprs
|
|
from fb_detr.utils.stream import stream_pages, chunk_iterable, get_page_count
|
|
|
|
|
|
def load_model(checkpoint_path):
    """Construct the DETR model from CLI arguments and load checkpoint weights.

    The command-line arguments (see ``detr.prediction.get_args_parser``)
    configure the model architecture; the weights come from
    ``checkpoint_path``. The model is moved to the device named in
    ``CONFIG.estimator.device`` before being returned.
    """
    cli = argparse.ArgumentParser(parents=[get_args_parser()])
    args = cli.parse_args()

    # Make sure the output directory exists before anything tries to write there.
    if args.output_dir:
        Path(args.output_dir).mkdir(parents=True, exist_ok=True)

    device = torch.device(CONFIG.estimator.device)

    model, _, _ = build_model(args)

    # Deserialize on CPU first, then move -- the target device is not
    # required for loading the checkpoint itself.
    state = torch.load(checkpoint_path, map_location="cpu")
    model.load_state_dict(state["model"])
    model.to(device)

    return model
|
|
|
|
|
|
class Predictor:
    """Runs DETR inference on images / PDF pages and post-processes the output.

    Post-processing turns raw model output into JSON-friendly dicts, optionally
    maps class indices to labels, drops detections of ``rejection_class`` and
    applies greedy non-maximum suppression.
    """

    def __init__(self, checkpoint_path, classes=None, rejection_class=None):
        """
        :param checkpoint_path: path of the model checkpoint to load
        :param classes: optional indexable mapping of class index -> label;
            when falsy, raw numeric class ids are returned
        :param rejection_class: label whose detections are filtered out;
            when falsy, no filtering happens
        """
        self.model = load_model(checkpoint_path)
        self.classes = classes
        self.rejection_class = rejection_class

    @staticmethod
    def __format_boxes(boxes):
        """Turn an (N, 4) box tensor of [x1, y1, x2, y2] rows into a list of dicts."""
        keys = ("x1", "y1", "x2", "y2")
        columns = (boxes[:, col].tolist() for col in range(4))
        return [dict(zip(keys, corners)) for corners in zip(*columns)]

    @staticmethod
    def __normalize_to_list(maybe_multiple):
        """Wrap a scalar into a 1-tuple; pass tuples through unchanged."""
        if isinstance(maybe_multiple, tuple):
            return maybe_multiple
        return (maybe_multiple,)

    def __format_classes(self, classes):
        """Map numeric class ids to labels via ``self.classes`` when available."""
        indices = classes.tolist()
        if not self.classes:
            return indices
        # itemgetter returns a bare scalar for a single index -- normalize.
        return self.__normalize_to_list(itemgetter(*indices)(self.classes))

    @staticmethod
    def __format_probas(probas):
        # NOTE(review): assumes ``probas`` is a NumPy-style array where
        # ``.max(axis=1)`` returns an array -- a torch tensor would return a
        # (values, indices) pair here. Confirm against ``infer``'s output type.
        return probas.max(axis=1).tolist()

    def __format_prediction(self, predictions: dict):
        """Format one page's raw prediction dict in place and return it."""
        boxes, classes, probas = itemgetter("bboxes", "classes", "probas")(predictions)

        if len(boxes):
            boxes = self.__format_boxes(boxes)
            classes = self.__format_classes(classes)
            probas = self.__format_probas(probas)
        else:
            # No detections on this page -- emit empty, JSON-friendly lists.
            boxes, classes, probas = [], [], []

        predictions["bboxes"] = boxes
        predictions["classes"] = classes
        predictions["probas"] = probas

        return predictions

    def __filter_predictions_for_image(self, predictions):
        """Drop every detection labeled ``self.rejection_class`` from one page (in place)."""
        boxes, classes, probas = itemgetter("bboxes", "classes", "probas")(predictions)

        if boxes:
            keep = (c != self.rejection_class for c in classes)
            kept = list(compress(zip(boxes, classes, probas), keep))
            boxes, classes, probas = map(list, zip(*kept)) if kept else ([], [], [])
            predictions["bboxes"] = boxes
            predictions["classes"] = classes
            predictions["probas"] = probas

        return predictions

    def filter_predictions(self, predictions):
        """Remove rejected detections and drop pages left without any detection.

        Each surviving page dict gains a ``page_idx`` key giving the page's
        position within ``predictions``. Returns a lazy iterator.
        """

        def build_return_dict(page_idx, prediction):
            return {"page_idx": page_idx, **prediction}

        filtered_rejections = map(self.__filter_predictions_for_image, predictions)
        # Stdlib replacement for iteration_utilities.starfilter: keep only the
        # (index, prediction) pairs whose page still has detections.
        with_detections = (
            (page_idx, prediction)
            for page_idx, prediction in enumerate(filtered_rejections)
            if prediction["classes"]
        )

        return starmap(build_return_dict, with_detections)

    def format_predictions(self, outputs: Iterable):
        """Lazily format every raw prediction dict in ``outputs``."""
        return map(self.__format_prediction, outputs)

    def __non_max_supprs(self, predictions):
        """Lazily apply greedy non-maximum suppression to every page."""
        return map(greedy_non_max_supprs, predictions)

    def predict(self, images, threshold=None):
        """Run inference on a batch of images and return post-processed predictions.

        :param images: batch of page images accepted by ``detr.prediction.infer``
        :param threshold: confidence threshold; falls back to
            ``CONFIG.estimator.threshold`` when falsy (note: an explicit
            threshold of 0 is therefore treated as "unset")
        :return: list of prediction dicts; filtered (and tagged with
            ``page_idx``) when ``self.rejection_class`` is set
        """
        if not threshold:
            threshold = CONFIG.estimator.threshold

        predictions = infer(images, self.model, CONFIG.estimator.device, threshold)
        predictions = self.format_predictions(predictions)

        if self.rejection_class:
            predictions = self.filter_predictions(predictions)

        predictions = self.__non_max_supprs(predictions)

        return list(predictions)

    def predict_pdf(self, pdf: bytes):
        """Predict detections for every page of ``pdf``, batch by batch.

        :return: flat list of prediction dicts, each carrying a global
            ``page_idx`` (0-based page number within the whole PDF)
        """

        def predict_batch(batch_idx, batch):
            # Global page number = pages consumed by earlier (full) batches
            # plus the page's position inside this batch. Fixes the original
            # ``+= batch_idx``, which offset by the batch *number* instead of
            # the page count and was wrong for batch_size > 1.
            offset = batch_idx * CONFIG.service.batch_size
            predictions = self.predict(batch)
            for position, prediction in enumerate(predictions):
                # Without rejection filtering, predictions carry no page_idx;
                # the unfiltered list is in page order, so the position is the
                # in-batch page index (resolves the old TODO / KeyError).
                prediction.setdefault("page_idx", position)
                prediction["page_idx"] += offset
            return predictions

        page_count = get_page_count(pdf)
        # Ceiling division: a trailing partial batch still counts toward the
        # tqdm total (the old floor division under-counted it).
        batch_count = -(-page_count // CONFIG.service.batch_size)

        page_stream = stream_pages(pdf)
        page_batches = chunk_iterable(page_stream, CONFIG.service.batch_size)
        batches = tqdm(enumerate(page_batches), total=batch_count)

        return list(chain.from_iterable(starmap(predict_batch, batches)))