"""DETR inference helpers: model loading plus prediction formatting and filtering."""

import argparse
from itertools import compress, starmap
from operator import itemgetter
from pathlib import Path
from typing import Iterable

import torch
from detr.models import build_model
from detr.test import get_args_parser, infer
from iteration_utilities import starfilter

from fb_detr.utils.config import read_config


def load_model(checkpoint_path):
    """Build the model from command-line arguments and load checkpoint weights.

    The model configuration comes from the argument parser provided by
    ``detr.test.get_args_parser``; the weights come from the ``"model"`` entry
    of the checkpoint file. If the parsed arguments name an output directory,
    it is created. The model is moved to the device named by the ``"device"``
    config entry before being returned.

    Args:
        checkpoint_path: Path to a checkpoint readable by ``torch.load`` that
            contains a ``"model"`` state dict.

    Returns:
        The model with the checkpoint weights loaded, on the configured device.
    """
    parser = argparse.ArgumentParser(parents=[get_args_parser()])
    # NOTE(review): parse_args() without arguments reads sys.argv, so the model
    # configuration depends on the command line of the hosting process.
    args = parser.parse_args()

    if args.output_dir:
        Path(args.output_dir).mkdir(parents=True, exist_ok=True)

    device = torch.device(read_config("device"))

    # build_model returns three values; only the first (the model) is used here.
    model, _, _ = build_model(args)

    # SECURITY: torch.load unpickles the file; only load trusted checkpoints.
    # Weights are first mapped to CPU and moved to the target device afterwards.
    checkpoint = torch.load(checkpoint_path, map_location="cpu")
    model.load_state_dict(checkpoint["model"])

    model.to(device)

    return model


class Predictor:
    """Runs inference with a loaded model and post-processes the predictions.

    Args:
        checkpoint_path: Checkpoint passed on to ``load_model``.
        classes: Optional lookup (indexable by predicted class index) used to
            translate class indices into labels when formatting predictions.
        rejection_class: Optional class value; detections whose class equals it
            are removed by ``filter_predictions``. ``None`` disables filtering.
    """

    def __init__(self, checkpoint_path, classes=None, rejection_class=None):
        self.model = load_model(checkpoint_path)
        self.classes = classes
        self.rejection_class = rejection_class

    @staticmethod
    def __format_boxes(boxes):
        """Convert a 2-D box array into a list of ``{"x1", "y1", "x2", "y2"}`` dicts.

        Columns 0..3 of ``boxes`` are read as x1, y1, x2, y2 respectively.
        """
        keys = ("x1", "y1", "x2", "y2")
        columns = (boxes[:, idx].tolist() for idx in range(len(keys)))
        return [dict(zip(keys, coords)) for coords in zip(*columns)]

    def __format_classes(self, classes):
        """Convert predicted class indices into labels.

        Returns a tuple of labels looked up in ``self.classes`` when a class
        mapping is configured, otherwise the plain list of indices.
        """
        indices = classes.tolist()
        if self.classes:
            # Direct lookup always yields a tuple, regardless of how many
            # indices there are and of the type of the individual labels.
            return tuple(self.classes[idx] for idx in indices)
        return indices

    def __format_prediction(self, output: dict) -> dict:
        """Replace ``"bboxes"`` and ``"classes"`` of ``output`` by formatted values.

        The dict is modified in place and returned. When there are no boxes,
        both entries become empty lists.
        """
        boxes, classes = itemgetter("bboxes", "classes")(output)

        if len(boxes):
            boxes = self.__format_boxes(boxes)
            classes = self.__format_classes(classes)
        else:
            boxes, classes = [], []

        output["bboxes"] = boxes
        output["classes"] = classes

        return output

    def __filter_predictions_for_image(self, predictions: dict) -> dict:
        """Drop detections of the rejection class from one image's predictions.

        The dict is modified in place and returned.
        NOTE(review): the truthiness test on the boxes suggests this expects
        already formatted predictions (plain lists) - confirm with callers.
        """
        boxes, classes = itemgetter("bboxes", "classes")(predictions)

        if boxes:
            keep = [label != self.rejection_class for label in classes]
            kept_pairs = list(compress(zip(boxes, classes), keep))
            # zip(*[]) yields nothing to unpack, hence the explicit empty case.
            boxes, classes = map(list, zip(*kept_pairs)) if kept_pairs else ([], [])
            predictions["bboxes"] = boxes
            predictions["classes"] = classes

        return predictions

    def filter_predictions(self, predictions):
        """Remove rejected detections and images left without detections.

        Args:
            predictions: Iterable of per-image prediction dicts with
                ``"bboxes"`` and ``"classes"`` entries.

        Returns:
            A lazy iterator of dicts, one per image that still has detections,
            each extended by ``"page_idx"``: the position of the image in the
            input iterable.
        """

        def detections_present(_, prediction):
            return bool(prediction["classes"])

        def build_return_dict(page_idx, prediction):
            return {"page_idx": page_idx, **prediction}

        without_rejected = map(self.__filter_predictions_for_image, predictions)
        # Enumerate before filtering so page_idx refers to the original position.
        with_detections = starfilter(detections_present, enumerate(without_rejected))
        return starmap(build_return_dict, with_detections)

    def format_predictions(self, outputs: Iterable):
        """Lazily format every prediction dict in ``outputs``."""
        return map(self.__format_prediction, outputs)

    def predict(self, images, threshold=None, format_output=False):
        """Run inference on ``images`` and optionally format/filter the result.

        Args:
            images: Images handed to ``detr.test.infer``.
            threshold: Threshold handed to ``infer``; when ``None`` the
                ``"threshold"`` config entry is used. An explicit ``0`` is
                respected.
            format_output: When true, predictions are passed through
                ``format_predictions``.

        Returns:
            The predictions from ``infer``, formatted if requested, and passed
            through ``filter_predictions`` when a rejection class is configured.
        """
        # Compare against None so that a threshold of 0 is not overridden.
        if threshold is None:
            threshold = read_config("threshold")

        predictions = infer(images, self.model, read_config("device"), threshold)

        if format_output:
            predictions = self.format_predictions(predictions)

        # Compare against None so that a falsy rejection class (e.g. index 0)
        # still triggers filtering.
        if self.rejection_class is not None:
            predictions = self.filter_predictions(predictions)

        return predictions