"""Flask server that runs an ``fb_detr`` predictor over uploaded PDFs.

A ``POST /`` request carries raw PDF bytes in its body; the pages are
rendered to images with ``pdf2image`` and passed to the predictor, whose
output is returned as JSON.  ``/ready``, ``/health`` and ``/status`` are
trivial probes that always answer ``"OK"``.
"""

import argparse
import logging
import os
import warnings

from flask import Flask, request, jsonify
from pdf2image import pdf2image

from fb_detr.locations import DATA_DIR
from fb_detr.locations import TORCH_HOME
from fb_detr.predictor import Predictor
from fb_detr.utils.config import read_config


def suppress_userwarnings():
    """Install an "ignore" warnings filter for the rest of the process.

    NOTE: despite the function name, the filter carries no category
    restriction, so every warning category is silenced, not only
    ``UserWarning``.
    """
    warnings.filterwarnings("ignore")


def parse_args():
    """Parse the command line.

    Returns:
        argparse.Namespace with ``resume`` (str or None) and ``warnings``
        (bool, False unless ``--warnings`` is given).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--resume",
        help="checkpoint to load instead of the one named in the config",
    )
    parser.add_argument(
        "--warnings",
        action="store_true",
        help="do not suppress Python warnings",
    )
    return parser.parse_args()


def load_classes():
    """Return a dict mapping 1-based ids to the entries of the ``classes`` config."""
    return dict(enumerate(read_config("classes"), start=1))


def get_checkpoint():
    """Return the configured checkpoint location, resolved under ``DATA_DIR``."""
    return DATA_DIR / read_config("checkpoint")


def set_torch_env():
    """Export ``TORCH_HOME`` into the process environment as a string."""
    os.environ["TORCH_HOME"] = str(TORCH_HOME)


def main(args):
    """Entry point: optionally silence warnings, then start the server.

    Args:
        args: parsed command-line namespace as returned by ``parse_args``.
    """
    if not args.warnings:
        suppress_userwarnings()
    run_server(args.resume)


def run_server(resume, *, host="127.0.0.1", port=5000, debug=True):
    """Build the predictor and the Flask app, then serve until interrupted.

    Args:
        resume: checkpoint to load; when falsy, the checkpoint from the
            config (see ``get_checkpoint``) is used instead.
        host: interface passed to ``Flask.run``.
        port: TCP port passed to ``Flask.run``.
        debug: Flask debug flag passed to ``Flask.run``.  The default keeps
            the previous hard-coded value.
    """
    set_torch_env()

    def initialize_predictor():
        checkpoint = resume if resume else get_checkpoint()
        return Predictor(
            checkpoint,
            classes=load_classes(),
            rejection_class=read_config("rejection_class"),
        )

    # Build the predictor before registering the routes that close over it.
    predictor = initialize_predictor()
    logging.info("Predictor ready.")

    app = Flask(__name__)

    def ok_response():
        # jsonify responses default to HTTP 200; set it explicitly anyway to
        # mirror the probes' documented contract.
        response = jsonify("OK")
        response.status_code = 200
        return response

    @app.route("/ready", methods=["GET"])
    def ready():
        return ok_response()

    @app.route("/health", methods=["GET"])
    def healthy():
        return ok_response()

    @app.route("/", methods=["POST"])
    def predict_request():
        # Request boundary: any failure while rendering or predicting is
        # logged with its traceback and reported to the client as HTTP 500.
        try:
            pages = pdf2image.convert_from_bytes(request.data)
            predictions = predictor.predict(pages)
            return jsonify(list(predictions))
        except Exception:
            logging.exception("Analysis failed")
            response = jsonify("Analysis failed")
            response.status_code = 500
            return response

    @app.route("/status", methods=["GET"])
    def status():
        return ok_response()

    # NOTE(review): debug=True turns on Werkzeug's interactive debugger and
    # auto-reloader.  The debugger permits code execution from the browser,
    # and the reloader re-executes this script in a child process, so the
    # predictor is presumably built twice.  Pass debug=False outside of
    # local development -- confirm the intended deployment mode.
    app.run(host=host, port=port, debug=debug)


if __name__ == "__main__":
    args = parse_args()
    main(args)