import argparse
import logging
from itertools import chain, compress, starmap
from operator import itemgetter
from pathlib import Path
from typing import Iterable

import torch
from iteration_utilities import starfilter
from tqdm import tqdm

from detr.models import build_model
from detr.prediction import get_args_parser, infer
from fb_detr.config import CONFIG
from fb_detr.utils.non_max_supprs import greedy_non_max_supprs
from fb_detr.utils.stream import stream_pages, chunk_iterable, get_page_count


def load_model(checkpoint_path):
    """Build the DETR model and load its weights from *checkpoint_path*.

    NOTE(review): this parses ``sys.argv`` at call time via the DETR argument
    parser, so it only behaves correctly when the process was launched with the
    flags that parser expects — confirm all callers are CLI entry points.

    :param checkpoint_path: path to a torch checkpoint holding a ``"model"``
        state dict.
    :return: the model, moved to the device named in ``CONFIG.estimator.device``.
    """
    parser = argparse.ArgumentParser(parents=[get_args_parser()])
    args = parser.parse_args()
    if args.output_dir:
        Path(args.output_dir).mkdir(parents=True, exist_ok=True)
    device = torch.device(CONFIG.estimator.device)
    model, _, _ = build_model(args)
    # Load onto CPU first, then move to the configured device.
    checkpoint = torch.load(checkpoint_path, map_location="cpu")
    model.load_state_dict(checkpoint["model"])
    model.to(device)
    return model


class Predictor:
    """DETR-based detector producing JSON-friendly per-page predictions.

    :param checkpoint_path: checkpoint passed to :func:`load_model`.
    :param classes: optional sequence mapping numeric class indices to labels.
    :param rejection_class: optional class value whose detections are dropped;
        pages left with no detections are dropped entirely.
    """

    def __init__(self, checkpoint_path, classes=None, rejection_class=None):
        self.model = load_model(checkpoint_path)
        self.classes = classes
        self.rejection_class = rejection_class

    @staticmethod
    def __format_boxes(boxes):
        # Convert an (N, 4) box tensor/array into a list of corner dicts.
        keys = "x1", "y1", "x2", "y2"
        x1s = boxes[:, 0].tolist()
        y1s = boxes[:, 1].tolist()
        x2s = boxes[:, 2].tolist()
        y2s = boxes[:, 3].tolist()
        return [dict(zip(keys, vs)) for vs in zip(x1s, y1s, x2s, y2s)]

    @staticmethod
    def __normalize_to_list(maybe_multiple):
        # itemgetter with a single key returns a bare value instead of a
        # tuple; normalize both cases to a tuple.
        return maybe_multiple if isinstance(maybe_multiple, tuple) else (maybe_multiple,)

    def __format_classes(self, classes):
        # Map numeric class indices to labels when a label list was supplied.
        if self.classes:
            return self.__normalize_to_list(itemgetter(*classes.tolist())(self.classes))
        return classes.tolist()

    @staticmethod
    def __format_probas(probas):
        # Keep only the best per-detection score.
        # NOTE(review): assumes ``probas.max(axis=1)`` yields an object with
        # ``.tolist()`` (e.g. a numpy array); a torch tensor would return a
        # (values, indices) namedtuple here — confirm the type infer() emits.
        return probas.max(axis=1).tolist()

    def __format_prediction(self, predictions: dict):
        """Replace raw tensors in *predictions* with plain Python values, in place."""
        boxes, classes, probas = itemgetter("bboxes", "classes", "probas")(predictions)
        if len(boxes):
            boxes = self.__format_boxes(boxes)
            classes = self.__format_classes(classes)
            probas = self.__format_probas(probas)
        else:
            boxes, classes, probas = [], [], []
        predictions["bboxes"] = boxes
        predictions["classes"] = classes
        predictions["probas"] = probas
        return predictions

    def __filter_predictions_for_image(self, predictions):
        """Drop every detection whose class equals ``self.rejection_class``."""
        boxes, classes, probas = itemgetter("bboxes", "classes", "probas")(predictions)
        if boxes:
            keep = map(lambda c: c != self.rejection_class, classes)
            compressed = list(compress(zip(boxes, classes, probas), keep))
            boxes, classes, probas = map(list, zip(*compressed)) if compressed else ([], [], [])
        predictions["bboxes"] = boxes
        predictions["classes"] = classes
        predictions["probas"] = probas
        return predictions

    def filter_predictions(self, predictions):
        """Remove rejected detections, drop detection-free pages, and tag each
        surviving prediction with its batch-local ``page_idx``."""

        def detections_present(_, prediction):
            return bool(prediction["classes"])

        def build_return_dict(page_idx, prediction):
            return {"page_idx": page_idx, **prediction}

        filtered_rejections = map(self.__filter_predictions_for_image, predictions)
        # Enumerate BEFORE dropping empty pages so page_idx stays correct.
        filtered_no_detections = starfilter(detections_present, enumerate(filtered_rejections))
        return starmap(build_return_dict, filtered_no_detections)

    def format_predictions(self, outputs: Iterable):
        """Lazily format each raw model output into plain Python values."""
        return map(self.__format_prediction, outputs)

    def __non_max_supprs(self, predictions):
        # Lazily apply greedy non-max suppression per prediction dict.
        return map(greedy_non_max_supprs, predictions)

    def predict(self, images, threshold=None):
        """Run inference on a batch of *images*.

        :param threshold: score threshold; defaults to ``CONFIG.estimator.threshold``.
        :return: list of prediction dicts, each carrying ``bboxes``/``classes``/
            ``probas`` and a batch-local ``page_idx``.
        """
        # FIX: ``if not threshold`` discarded an explicit threshold of 0.0;
        # only fall back to the config default when the caller passed nothing.
        if threshold is None:
            threshold = CONFIG.estimator.threshold
        predictions = infer(images, self.model, CONFIG.estimator.device, threshold)
        predictions = self.format_predictions(predictions)
        # FIX: ``if self.rejection_class`` skipped filtering for a falsy class
        # value such as id 0; test for None explicitly.
        if self.rejection_class is not None:
            predictions = self.filter_predictions(predictions)
        else:
            # FIX (resolves TODO): attach page_idx on the unfiltered path too,
            # so predict_pdf's offsetting never hits a missing key.
            predictions = (
                {"page_idx": page_idx, **prediction}
                for page_idx, prediction in enumerate(predictions)
            )
        predictions = self.__non_max_supprs(predictions)
        return list(predictions)

    def predict_pdf(self, pdf: bytes):
        """Predict over every page of *pdf*, batch by batch.

        :return: flat list of prediction dicts whose ``page_idx`` is a
            document-global page index.
        """
        batch_size = CONFIG.service.batch_size

        def progress(generator):
            if not CONFIG.service.verbose:
                return generator
            page_count = get_page_count(pdf)
            # FIX: ceiling division — floor undercounted the tqdm total
            # whenever the last batch was partial.
            batch_count = -(-page_count // batch_size)
            return tqdm(generator, total=batch_count, position=1, leave=True)

        def predict_batch(batch_idx, batch):
            predictions = self.predict(batch)
            for prediction in predictions:
                # FIX: offset by the pages consumed by earlier batches, not by
                # the batch number (the old ``+= batch_idx`` was only correct
                # for batch_size == 1).
                prediction["page_idx"] += batch_idx * batch_size
            return predictions

        page_stream = stream_pages(pdf)
        page_batches = chunk_iterable(page_stream, batch_size)
        batched = starmap(predict_batch, progress(enumerate(page_batches)))
        return list(chain.from_iterable(batched))