Matthias Bisping 3e2cb94060 Pull request #9: Docker image tuning, batching of pdf pages and misc other
Merge in RR/fb_detr_prediction_container from docker-image-tuning to master

Squashed commit of the following:

commit 9b30e6317aaf892fcb6f87275d03e2efb76954bf
Author: Matthias Bisping <matthias.bisping@iqser.com>
Date:   Mon Feb 21 15:17:01 2022 +0100

    applied black

commit 84a57ac29723910dbc2c4d8ccce58c9d3131a305
Author: Matthias Bisping <matthias.bisping@iqser.com>
Date:   Mon Feb 21 15:15:15 2022 +0100

    refactoring of tqdm

commit b26c52765c58125826099072d510a39baabce73e
Author: Julius Unverfehrt <Julius.Unverfehrt@iqser.com>
Date:   Mon Feb 21 14:52:11 2022 +0100

    correcting versioning of docker-compose

commit 23752eec0d95cc543f15a86c78bd8531ebfdde7d
Author: Matthias Bisping <matthias.bisping@iqser.com>
Date:   Mon Feb 21 14:49:26 2022 +0100

    put tqdm progress in different place

commit e2e109ea7125c90f5b15ec374f3cbfef41e2ee9e
Author: Matthias Bisping <matthias.bisping@iqser.com>
Date:   Mon Feb 21 11:51:09 2022 +0100

    fixed batching index bug

commit 6ca508ac55dd02ded356617653f580099e1cf186
Author: Matthias Bisping <matthias.bisping@iqser.com>
Date:   Mon Feb 21 11:37:34 2022 +0100

    batching WIP

commit 0ceb7c1415b10230397f4860ac4e314d44bfbfd1
Author: Matthias Bisping <matthias.bisping@iqser.com>
Date:   Fri Feb 18 19:21:02 2022 +0100

    debug mode for webserver renamed

commit 617f07a0296ad3efc85b6ee52d1641cdfa22d3d3
Author: Matthias Bisping <matthias.bisping@iqser.com>
Date:   Fri Feb 18 18:25:39 2022 +0100

    refactoring, better logging, added compose file for local testing

commit a24f799614e22481dd20b578c354e33474bec5c0
Author: Matthias Bisping <matthias.bisping@iqser.com>
Date:   Fri Feb 18 17:31:14 2022 +0100

    updated submodule

commit 67b64606e081373e5c30ccf5bfafcb91dcc9a74e
Author: Matthias Bisping <matthias.bisping@iqser.com>
Date:   Fri Feb 18 17:29:20 2022 +0100

    cleanup: better config; refactoring; renaming

commit c3a1ab560879d6a1e6ce003c74a07d62175316f7
Author: Matthias Bisping <matthias.bisping@iqser.com>
Date:   Fri Feb 18 15:55:35 2022 +0100

    tweaked dockerfiles

commit 43f7a32265243bc0f110bd307325b5404e8726a8
Author: Matthias Bisping <matthias.bisping@iqser.com>
Date:   Fri Feb 18 15:02:49 2022 +0100

    added dockerignore
2022-02-21 15:36:38 +01:00

157 lines
4.9 KiB
Python

import argparse
import logging
from itertools import compress, starmap, chain
from operator import itemgetter
from pathlib import Path
from typing import Iterable
import torch
from iteration_utilities import starfilter
from tqdm import tqdm
from detr.models import build_model
from detr.prediction import get_args_parser, infer
from fb_detr.config import CONFIG
from fb_detr.utils.non_max_supprs import greedy_non_max_supprs
from fb_detr.utils.stream import stream_pages, chunk_iterable, get_page_count
def load_model(checkpoint_path):
    """Build the DETR model and load its weights from a checkpoint.

    NOTE(review): this parses ``sys.argv`` through the detr argument parser,
    so it should only be called from an entry point that owns the command
    line — confirm before reusing it elsewhere.

    :param checkpoint_path: path to a torch checkpoint whose ``"model"`` entry
        holds the state dict.
    :return: the model, moved to the device named in ``CONFIG.estimator.device``.
    """
    arg_parser = argparse.ArgumentParser(parents=[get_args_parser()])
    cli_args = arg_parser.parse_args()

    # Side effect kept from the original script: ensure the output dir exists.
    if cli_args.output_dir:
        Path(cli_args.output_dir).mkdir(parents=True, exist_ok=True)

    model, _, _ = build_model(cli_args)

    # Load onto CPU first, then move to the configured device.
    state = torch.load(checkpoint_path, map_location="cpu")
    model.load_state_dict(state["model"])
    model.to(torch.device(CONFIG.estimator.device))
    return model
class Predictor:
    """Runs DETR inference and post-processes raw output into plain dicts.

    Each result dict carries the keys ``"bboxes"``, ``"classes"`` and
    ``"probas"`` as JSON-serializable lists, plus ``"page_idx"`` for
    page-indexed results.
    """

    def __init__(self, checkpoint_path, classes=None, rejection_class=None):
        """
        :param checkpoint_path: path of the torch checkpoint to load.
        :param classes: optional sequence mapping numeric class ids to labels.
        :param rejection_class: optional class whose detections are discarded.
        """
        self.model = load_model(checkpoint_path)
        self.classes = classes
        self.rejection_class = rejection_class

    @staticmethod
    def __format_boxes(boxes):
        # boxes: (N, 4) array of [x1, y1, x2, y2] rows -> list of dicts.
        keys = "x1", "y1", "x2", "y2"
        columns = (boxes[:, i].tolist() for i in range(4))
        return [dict(zip(keys, row)) for row in zip(*columns)]

    @staticmethod
    def __normalize_to_list(maybe_multiple):
        # itemgetter with a single index returns a bare value; wrap it so
        # callers always receive a tuple.
        if isinstance(maybe_multiple, tuple):
            return maybe_multiple
        return (maybe_multiple,)

    def __format_classes(self, classes):
        # Map numeric class ids to human-readable labels when a class list
        # was provided; otherwise return the raw ids.
        if self.classes:
            return self.__normalize_to_list(itemgetter(*classes.tolist())(self.classes))
        return classes.tolist()

    @staticmethod
    def __format_probas(probas):
        # Keep only the highest class probability per detection.
        return probas.max(axis=1).tolist()

    def __format_prediction(self, predictions: dict):
        """Convert one raw (array-valued) prediction dict into plain lists.

        Mutates and returns *predictions*.
        """
        boxes, classes, probas = itemgetter("bboxes", "classes", "probas")(predictions)
        if len(boxes):
            boxes = self.__format_boxes(boxes)
            classes = self.__format_classes(classes)
            probas = self.__format_probas(probas)
        else:
            boxes, classes, probas = [], [], []
        predictions["bboxes"] = boxes
        predictions["classes"] = classes
        predictions["probas"] = probas
        return predictions

    def __filter_predictions_for_image(self, predictions):
        """Drop every detection whose class equals the rejection class."""
        boxes, classes, probas = itemgetter("bboxes", "classes", "probas")(predictions)
        if boxes:
            keep = [c != self.rejection_class for c in classes]
            kept = list(compress(zip(boxes, classes, probas), keep))
            boxes, classes, probas = map(list, zip(*kept)) if kept else ([], [], [])
        predictions["bboxes"] = boxes
        predictions["classes"] = classes
        predictions["probas"] = probas
        return predictions

    def filter_predictions(self, predictions):
        """Remove rejected detections, drop pages left without detections,
        and tag each surviving page with its 0-based ``page_idx``."""

        def detections_present(_, prediction):
            return bool(prediction["classes"])

        def build_return_dict(page_idx, prediction):
            return {"page_idx": page_idx, **prediction}

        filtered_rejections = map(self.__filter_predictions_for_image, predictions)
        # Enumerate BEFORE dropping empty pages so page_idx stays correct.
        with_idx = starfilter(detections_present, enumerate(filtered_rejections))
        return starmap(build_return_dict, with_idx)

    def format_predictions(self, outputs: Iterable):
        """Lazily format every raw prediction dict in *outputs*."""
        return map(self.__format_prediction, outputs)

    def __non_max_supprs(self, predictions):
        # Apply greedy non-maximum suppression to each page's detections.
        return map(greedy_non_max_supprs, predictions)

    def predict(self, images, threshold=None):
        """Run inference on *images* and return a list of prediction dicts,
        each tagged with its 0-based ``page_idx``.

        :param threshold: confidence threshold; falls back to
            ``CONFIG.estimator.threshold`` when omitted.
        """
        # BUGFIX: `if not threshold` also clobbered an explicit 0.0.
        if threshold is None:
            threshold = CONFIG.estimator.threshold
        predictions = infer(images, self.model, CONFIG.estimator.device, threshold)
        predictions = self.format_predictions(predictions)
        # BUGFIX: truthiness check broke a rejection class id of 0.
        if self.rejection_class is not None:
            predictions = self.filter_predictions(predictions)
        else:
            # Resolves the old TODO: page_idx is set even when not filtering,
            # which predict_pdf relies on for its batch offset.
            predictions = ({"page_idx": i, **p} for i, p in enumerate(predictions))
        predictions = self.__non_max_supprs(predictions)
        return list(predictions)

    def predict_pdf(self, pdf: bytes):
        """Predict on every page of *pdf*, streaming pages in batches of
        ``CONFIG.service.batch_size`` and returning one flat list of
        prediction dicts with document-global ``page_idx`` values."""
        batch_size = CONFIG.service.batch_size

        def predict_batch(batch_idx, batch):
            predictions = self.predict(batch)
            for p in predictions:
                # BUGFIX: offset by the number of pages already consumed,
                # not by the batch number.
                p["page_idx"] += batch_idx * batch_size
            return predictions

        page_count = get_page_count(pdf)
        # BUGFIX: ceil division so tqdm's total counts a trailing partial batch.
        batch_count = -(-page_count // batch_size)
        page_batches = chunk_iterable(stream_pages(pdf), batch_size)
        batches = tqdm(enumerate(page_batches), total=batch_count)
        return list(chain.from_iterable(starmap(predict_batch, batches)))