109 lines
2.4 KiB
Python

import argparse
import logging
import os
from flask import Flask, request, jsonify
from pdf2image import pdf2image
from fb_detr.locations import DATA_DIR
from fb_detr.locations import TORCH_HOME
from fb_detr.predictor import Predictor
from fb_detr.utils.config import read_config
def suppress_userwarnings():
    """Install a global "ignore" filter in the ``warnings`` module.

    Despite the name, no category is given, so the filter applies to
    every ``Warning`` subclass, not only ``UserWarning``.
    """
    from warnings import simplefilter

    # With no message/module pattern this adds the same filter entry that
    # filterwarnings("ignore") would.
    simplefilter("ignore")
def parse_args(argv=None):
    """Parse the server's command-line options.

    Args:
        argv: Optional list of argument strings. When ``None`` (the
            default) argparse falls back to ``sys.argv[1:]``, which is the
            behaviour of the original zero-argument call.

    Returns:
        argparse.Namespace with:
            resume: value of ``--resume`` or ``None`` when not given.
            warnings: ``True`` when ``--warnings`` was passed, else ``False``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--resume",
        help="checkpoint to load instead of the one named in the config",
    )
    # store_true already defaults to False.
    parser.add_argument(
        "--warnings",
        action="store_true",
        help="do not silence Python warnings",
    )
    return parser.parse_args(argv)
def load_classes():
    """Return a dict numbering the configured "classes" entries from 1.

    The entries come from ``read_config("classes")``; the first entry gets
    key 1, the second key 2, and so on.
    """
    # presumably a list of class labels -- verify against the config file
    configured = read_config("classes")
    return dict(enumerate(configured, start=1))
def get_checkpoint():
    """Return the configured "checkpoint" value joined onto ``DATA_DIR``."""
    # DATA_DIR is combined with "/" -- presumably a pathlib.Path; confirm in
    # fb_detr.locations.
    relative_name = read_config("checkpoint")
    return DATA_DIR / relative_name
def set_torch_env():
    """Export ``TORCH_HOME`` (as a string) into this process's environment."""
    os.environ.update({"TORCH_HOME": str(TORCH_HOME)})
def main(args):
    """Apply the warning preference from *args*, then start the server.

    Args:
        args: Parsed options providing ``warnings`` (truthy keeps warnings
            enabled) and ``resume`` (forwarded to ``run_server``).
    """
    keep_warnings = args.warnings
    if not keep_warnings:
        suppress_userwarnings()

    checkpoint_override = args.resume
    run_server(checkpoint_override)
def run_server(resume, *, host="127.0.0.1", port=5000, debug=False):
    """Load the predictor and serve it with Flask's built-in server.

    Routes:
        GET  /ready, /health, /status: respond with the JSON string "OK".
        POST /: the raw request body is handed to
            ``pdf2image.convert_from_bytes``; the resulting pages go to
            ``predictor.predict`` and the predictions are returned as a
            JSON list. Any exception is logged with its traceback and
            answered with HTTP 500 and the JSON string "Analysis failed".

    Args:
        resume: Checkpoint to load. When falsy, ``get_checkpoint()`` is
            used instead.
        host: Interface passed to ``app.run``.
        port: Port passed to ``app.run``.
        debug: Passed to ``app.run``. Off by default because Flask's debug
            mode enables the interactive debugger (which allows code
            execution) and the auto-reloader (which re-runs this script in
            a child process, loading the predictor a second time).

    This call blocks until the server stops.
    """
    set_torch_env()

    def initialize_predictor():
        # An explicit --resume value wins over the configured checkpoint.
        checkpoint = resume or get_checkpoint()
        return Predictor(
            checkpoint,
            classes=load_classes(),
            rejection_class=read_config("rejection_class"),
        )

    app = Flask(__name__)

    def ok_response():
        # Shared body of the probe endpoints; jsonify defaults to HTTP 200.
        return jsonify("OK")

    # The three probe views keep their own function names so the Flask
    # endpoint names ("ready", "healthy", "status") stay as they were.
    @app.route("/ready", methods=["GET"])
    def ready():
        return ok_response()

    @app.route("/health", methods=["GET"])
    def healthy():
        return ok_response()

    @app.route("/", methods=["POST"])
    def predict_request():
        try:
            pages = pdf2image.convert_from_bytes(request.data)
            predictions = predictor.predict(pages)
            return jsonify(list(predictions))
        except Exception:
            # Request boundary: report every failure as a 500 instead of
            # letting it propagate. logging.exception records the message
            # and the traceback in one entry.
            logging.exception("Analysis failed")
            resp = jsonify("Analysis failed")
            resp.status_code = 500
            return resp

    @app.route("/status", methods=["GET"])
    def status():
        return ok_response()

    # Bound before app.run(), so predict_request never sees it unset.
    predictor = initialize_predictor()
    logging.info("Predictor ready.")
    app.run(host=host, port=port, debug=debug)
# Script entry point: parse the command line and hand the options to main().
if __name__ == "__main__":
    main(parse_args())