Julius Unverfehrt a4fa73deaa Pull request #14: optional debug progress bar added
Merge in RR/fb-detr from add-debug-progress-bar to master

Squashed commit of the following:

commit 3449be1b46f73a5e9ae3719ed2821a1b7faca9e4
Author: Matthias Bisping <matthias.bisping@iqser.com>
Date:   Wed Feb 23 10:26:47 2022 +0100

    refactoring; added VERBOSE flag to config

commit e50234e205dfd7a40aaf7981da85e28048d9efba
Merge: 89703ca f6c51be
Author: Matthias Bisping <matthias.bisping@iqser.com>
Date:   Wed Feb 23 09:45:33 2022 +0100

    Merge branch 'config_changes' into add-debug-progress-bar

commit f6c51beeaa952c18c80b7af6b7a46b9de8f521c3
Author: Matthias Bisping <matthias.bisping@iqser.com>
Date:   Wed Feb 23 09:44:00 2022 +0100

    added env var

commit 89703caa776f0fad55757ab22568e45949b2b310
Author: Julius Unverfehrt <Julius.Unverfehrt@iqser.com>
Date:   Wed Feb 23 08:28:52 2022 +0100

    optional debug progress bar added
2022-02-23 10:51:24 +01:00
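
For context, a minimal sketch of how the new flag is meant to be toggled. The env var name and the shape of fb_detr/config.py are not part of this diff, so the names below are assumptions for illustration only; the code in this PR only relies on CONFIG.service.verbose being truthy or falsy.

    import os

    class ServiceConfig:
        # Hypothetical stand-in for CONFIG.service; the real object lives in fb_detr.config.
        batch_size = 4
        # Predictor.predict_pdf shows its tqdm batch progress bar only when this is truthy.
        # "VERBOSE" is a guess at the env var added in commit f6c51be.
        verbose = os.getenv("VERBOSE", "").lower() in ("1", "true", "yes")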


import argparse
import logging
from itertools import compress, starmap, chain
from operator import itemgetter
from pathlib import Path
from typing import Iterable

import torch
from iteration_utilities import starfilter
from tqdm import tqdm

from detr.models import build_model
from detr.prediction import get_args_parser, infer
from fb_detr.config import CONFIG
from fb_detr.utils.non_max_supprs import greedy_non_max_supprs
from fb_detr.utils.stream import stream_pages, chunk_iterable, get_page_count


def load_model(checkpoint_path):
    """Build a DETR model from the CLI args and load weights from the given checkpoint."""
    parser = argparse.ArgumentParser(parents=[get_args_parser()])
    args = parser.parse_args()
    if args.output_dir:
        Path(args.output_dir).mkdir(parents=True, exist_ok=True)
    device = torch.device(CONFIG.estimator.device)
    model, _, _ = build_model(args)
    checkpoint = torch.load(checkpoint_path, map_location="cpu")
    model.load_state_dict(checkpoint["model"])
    model.to(device)
    return model


class Predictor:
    """Runs DETR inference and post-processes the raw predictions."""

    def __init__(self, checkpoint_path, classes=None, rejection_class=None):
        self.model = load_model(checkpoint_path)
        self.classes = classes
        self.rejection_class = rejection_class

    @staticmethod
    def __format_boxes(boxes):
        # Convert an (N, 4) box array into a list of {"x1", "y1", "x2", "y2"} dicts.
        keys = "x1", "y1", "x2", "y2"
        x1s = boxes[:, 0].tolist()
        y1s = boxes[:, 1].tolist()
        x2s = boxes[:, 2].tolist()
        y2s = boxes[:, 3].tolist()
        boxes = [dict(zip(keys, vs)) for vs in zip(x1s, y1s, x2s, y2s)]
        return boxes

    @staticmethod
    def __normalize_to_list(maybe_multiple):
        # itemgetter with a single index returns a scalar, so wrap it to keep the return type uniform.
        return maybe_multiple if isinstance(maybe_multiple, tuple) else (maybe_multiple,)

    def __format_classes(self, classes):
        if self.classes:
            # Map numeric class ids to their configured labels.
            return self.__normalize_to_list(itemgetter(*classes.tolist())(self.classes))
        else:
            return classes.tolist()

    @staticmethod
    def __format_probas(probas):
        # Keep only the highest class probability for each detection.
        return probas.max(axis=1).tolist()

    def __format_prediction(self, predictions: dict):
        boxes, classes, probas = itemgetter("bboxes", "classes", "probas")(predictions)
        if len(boxes):
            boxes = self.__format_boxes(boxes)
            classes = self.__format_classes(classes)
            probas = self.__format_probas(probas)
        else:
            boxes, classes, probas = [], [], []
        predictions["bboxes"] = boxes
        predictions["classes"] = classes
        predictions["probas"] = probas
        return predictions

    def __filter_predictions_for_image(self, predictions):
        # Drop every detection whose class equals the rejection class.
        boxes, classes, probas = itemgetter("bboxes", "classes", "probas")(predictions)
        if boxes:
            keep = map(lambda c: c != self.rejection_class, classes)
            compressed = list(compress(zip(boxes, classes, probas), keep))
            boxes, classes, probas = map(list, zip(*compressed)) if compressed else ([], [], [])
        predictions["bboxes"] = boxes
        predictions["classes"] = classes
        predictions["probas"] = probas
        return predictions

    def filter_predictions(self, predictions):
        def detections_present(_, prediction):
            return bool(prediction["classes"])

        # TODO: set page_idx even when not filtering
        def build_return_dict(page_idx, predictions):
            return {"page_idx": page_idx, **predictions}

        filtered_rejections = map(self.__filter_predictions_for_image, predictions)
        filtered_no_detections = starfilter(detections_present, enumerate(filtered_rejections))
        filtered_no_detections = starmap(build_return_dict, filtered_no_detections)
        return filtered_no_detections

    def format_predictions(self, outputs: Iterable):
        return map(self.__format_prediction, outputs)

    def __non_max_supprs(self, predictions):
        predictions = map(greedy_non_max_supprs, predictions)
        return predictions

    def predict(self, images, threshold=None):
        if threshold is None:
            threshold = CONFIG.estimator.threshold
        predictions = infer(images, self.model, CONFIG.estimator.device, threshold)
        predictions = self.format_predictions(predictions)
        if self.rejection_class:
            predictions = self.filter_predictions(predictions)
        predictions = self.__non_max_supprs(predictions)
        predictions = list(predictions)
        return predictions

    def predict_pdf(self, pdf: bytes):
        def progress(generator):
            # Wrap the batch generator in a tqdm progress bar when CONFIG.service.verbose is set.
            page_count = get_page_count(pdf)
            batch_count = -(-page_count // CONFIG.service.batch_size)  # ceiling division
            yield from (
                tqdm(generator, total=batch_count, position=1, leave=True)
                if CONFIG.service.verbose
                else generator
            )

        def predict_batch(batch_idx, batch):
            predictions = self.predict(batch)
            for p in predictions:
                # Offset the in-batch page index by the pages consumed in earlier batches.
                p["page_idx"] += batch_idx * CONFIG.service.batch_size
            return predictions

        page_stream = stream_pages(pdf)
        page_batches = chunk_iterable(page_stream, CONFIG.service.batch_size)
        predictions = list(chain(*starmap(predict_batch, progress(enumerate(page_batches)))))
        return predictions