import argparse
from itertools import compress, starmap
from operator import itemgetter
from pathlib import Path
from typing import Iterable

import torch
from detr.models import build_model
from detr.test import get_args_parser, infer
from iteration_utilities import starfilter

from fb_detr.utils.non_max_supprs import greedy_non_max_supprs
from fb_detr.utils.config import read_config


def load_model(checkpoint_path):
    """Build a DETR model from command-line arguments and load checkpoint weights.

    The architecture is configured by DETR's ``get_args_parser`` and
    ``parse_args()`` is called without an explicit argument list, so the
    arguments come from ``sys.argv`` of the current process.

    Args:
        checkpoint_path: Path to a checkpoint file whose ``"model"`` entry is a
            state dict for the built model.

    Returns:
        The model, moved to the device named by the ``"device"`` config entry
        and switched to evaluation mode.
    """
    parser = argparse.ArgumentParser(parents=[get_args_parser()])
    args = parser.parse_args()

    # Ensure the configured output directory exists.
    if args.output_dir:
        Path(args.output_dir).mkdir(parents=True, exist_ok=True)

    device = torch.device(read_config("device"))
    model, _, _ = build_model(args)

    # Map tensors to the CPU while loading so the checkpoint loads regardless
    # of the device it was saved from; the model is moved afterwards.
    # NOTE(review): torch.load unpickles the file - only load trusted checkpoints.
    checkpoint = torch.load(checkpoint_path, map_location="cpu")
    model.load_state_dict(checkpoint["model"])
    model.to(device)
    # This model is only used for inference; disable training-mode behaviour
    # (e.g. dropout) here rather than relying on the caller to do it.
    model.eval()
    return model


class Predictor:
    """Run a DETR model on images and post-process its detections.

    Args:
        checkpoint_path: Checkpoint passed to ``load_model``.
        classes: Optional lookup (sequence or mapping) from predicted class
            index to label. When falsy, raw class indices are returned.
        rejection_class: Optional class value (a label when ``classes`` is set,
            otherwise an index) whose detections are dropped. When ``None``,
            no rejection filtering is performed.
    """

    def __init__(self, checkpoint_path, classes=None, rejection_class=None):
        self.model = load_model(checkpoint_path)
        self.classes = classes
        self.rejection_class = rejection_class

    @staticmethod
    def __format_boxes(boxes):
        """Convert a 2-D box array into a list of ``{"x1", "y1", "x2", "y2"}`` dicts.

        Only the first four columns of each row are used.
        """
        keys = ("x1", "y1", "x2", "y2")
        # zip stops at the shorter sequence, so extra columns are ignored.
        return [dict(zip(keys, row)) for row in boxes.tolist()]

    def __format_classes(self, classes):
        """Return predicted classes as a list, mapped through ``self.classes`` when set."""
        indices = classes.tolist()
        if self.classes:
            # Look up one index at a time so the result is always a list, even
            # for a single detection (itemgetter with one key returns a bare item).
            return [self.classes[index] for index in indices]
        return indices

    @staticmethod
    def __format_probas(probas):
        """Return the highest score of each detection as a list."""
        # NOTE(review): assumes ``probas`` is a 2-D array whose ``max(axis=1)``
        # returns an array (numpy-style); a torch tensor would return a
        # (values, indices) pair here - confirm against ``infer``.
        return probas.max(axis=1).tolist()

    def __format_prediction(self, predictions: dict):
        """Convert one image's raw ``infer`` output into plain Python lists.

        The ``"bboxes"``, ``"classes"`` and ``"probas"`` entries of
        ``predictions`` are replaced in place; the same dict is returned.
        """
        boxes, classes, probas = itemgetter("bboxes", "classes", "probas")(predictions)
        if len(boxes):
            boxes = self.__format_boxes(boxes)
            classes = self.__format_classes(classes)
            probas = self.__format_probas(probas)
        else:
            boxes, classes, probas = [], [], []
        predictions["bboxes"] = boxes
        predictions["classes"] = classes
        predictions["probas"] = probas
        return predictions

    def __filter_predictions_for_image(self, predictions):
        """Drop detections whose class equals ``self.rejection_class`` (in place)."""
        boxes, classes, probas = itemgetter("bboxes", "classes", "probas")(predictions)
        if boxes:
            keep = (cls != self.rejection_class for cls in classes)
            kept = list(compress(zip(boxes, classes, probas), keep))
            # Transpose the kept (box, class, proba) triples back into three lists.
            boxes, classes, probas = map(list, zip(*kept)) if kept else ([], [], [])
        predictions["bboxes"] = boxes
        predictions["classes"] = classes
        predictions["probas"] = probas
        return predictions

    def filter_predictions(self, predictions):
        """Remove rejected detections and drop images left without detections.

        Returns a lazy iterator of dicts, each the image's prediction entries
        plus ``"page_idx"``: the image's position in ``predictions`` (counted
        before empty images are dropped).
        """

        def detections_present(_, prediction):
            return bool(prediction["classes"])

        def build_return_dict(page_idx, predictions):
            return {"page_idx": page_idx, **predictions}

        filtered_rejections = map(self.__filter_predictions_for_image, predictions)
        filtered_no_detections = starfilter(detections_present, enumerate(filtered_rejections))
        return starmap(build_return_dict, filtered_no_detections)

    def format_predictions(self, outputs: Iterable):
        """Lazily format each raw per-image output from ``infer``."""
        return map(self.__format_prediction, outputs)

    def __merge_boxes(self, predictions):
        """Lazily apply ``greedy_non_max_supprs`` to each per-image prediction."""
        # Presumably greedy non-maximum suppression, judging by the helper's name.
        return map(greedy_non_max_supprs, predictions)

    def predict(self, images, threshold=None):
        """Run inference on ``images`` and return post-processed predictions.

        Args:
            images: Passed unchanged to ``infer``.
            threshold: Score threshold passed to ``infer``. ``None`` (the
                default) uses the ``"threshold"`` config value; an explicit
                ``0``/``0.0`` is respected.

        Returns:
            A list with one entry per remaining image: the result of
            ``greedy_non_max_supprs`` applied to the formatted (and, when
            ``rejection_class`` is set, filtered) prediction dict.
        """
        if threshold is None:
            threshold = read_config("threshold")
        predictions = infer(images, self.model, read_config("device"), threshold)
        predictions = self.format_predictions(predictions)
        # Compare against None so falsy class values such as index 0 still filter.
        if self.rejection_class is not None:
            predictions = self.filter_predictions(predictions)
        predictions = self.__merge_boxes(predictions)
        return list(predictions)