109 lines
2.4 KiB
Python
109 lines
2.4 KiB
Python
import argparse
|
|
import logging
|
|
import os
|
|
|
|
from flask import Flask, request, jsonify
|
|
from pdf2image import pdf2image
|
|
|
|
from fb_detr.locations import DATA_DIR
|
|
from fb_detr.locations import TORCH_HOME
|
|
from fb_detr.predictor import Predictor
|
|
from fb_detr.utils.config import read_config
|
|
|
|
|
|
def suppress_userwarnings():
    """Silence Python warnings for the remainder of the process.

    Despite the name, this is not limited to ``UserWarning``: it installs an
    ``"ignore"`` filter with the default category, which matches every
    warning class.
    """
    from warnings import filterwarnings

    filterwarnings("ignore")
|
|
|
|
|
|
def parse_args(argv=None):
    """Parse the server's command-line options.

    Args:
        argv: Argument list to parse, without the program name. ``None`` (the
            default) parses ``sys.argv[1:]``, exactly as before. Passing an
            explicit list lets callers and tests drive the parser directly.

    Returns:
        argparse.Namespace with:
            resume   -- checkpoint path overriding the configured default, or None.
            warnings -- True to keep Python warnings visible (False by default).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--resume",
        help="checkpoint to load instead of the configured default",
    )
    parser.add_argument(
        "--warnings",
        action="store_true",
        default=False,
        help="show Python warnings (they are suppressed by default)",
    )
    return parser.parse_args(argv)
|
|
|
|
|
|
def load_classes():
    """Map 1-based class ids to the class names from the "classes" config entry.

    Returns:
        dict: ``{1: classes[0], 2: classes[1], ...}`` for the configured list.
    """
    # Ids start at 1 rather than 0; presumably 0 is reserved by the model
    # (e.g. background) -- TODO confirm against Predictor.
    return dict(enumerate(read_config("classes"), start=1))
|
|
|
|
|
|
def get_checkpoint():
    """Return the default checkpoint location.

    The file name comes from the "checkpoint" config entry and is joined onto
    DATA_DIR with ``/`` (assumes DATA_DIR is a pathlib.Path -- TODO confirm).
    """
    checkpoint_name = read_config("checkpoint")
    return DATA_DIR / checkpoint_name
|
|
|
|
|
|
def set_torch_env():
    """Export the project's TORCH_HOME into this process's environment."""
    torch_home = str(TORCH_HOME)
    os.environ["TORCH_HOME"] = torch_home
|
|
|
|
|
|
def main(args):
    """Program entry point: configure warnings, then start the server.

    Args:
        args: Namespace produced by ``parse_args`` (reads ``warnings`` and
            ``resume``).
    """
    # Warnings stay hidden unless --warnings was given.
    silence = not args.warnings
    if silence:
        suppress_userwarnings()

    run_server(args.resume)
|
|
|
|
|
|
def _initialize_predictor(resume):
    """Build the Predictor used by the prediction endpoint.

    Args:
        resume: Optional checkpoint path (from ``--resume``). When falsy, the
            configured default from ``get_checkpoint()`` is used instead.

    Returns:
        A ``Predictor`` configured with the id->class mapping from
        ``load_classes()`` and the configured ``rejection_class``.
    """
    checkpoint = resume if resume else get_checkpoint()
    return Predictor(
        checkpoint,
        classes=load_classes(),
        rejection_class=read_config("rejection_class"),
    )


def _create_app(predictor):
    """Create the Flask app that exposes ``predictor`` over HTTP.

    Routes:
        GET  /ready, /health, /status -- status endpoints that always return "OK".
        POST /  -- the request body is raw PDF bytes. Responds with the JSON list
                   of predictions, with 400 on an empty body, or with 500 if
                   conversion or prediction raises.
    """
    app = Flask(__name__)

    @app.route("/ready", methods=["GET"])
    def ready():
        return jsonify("OK"), 200

    @app.route("/health", methods=["GET"])
    def healthy():
        return jsonify("OK"), 200

    @app.route("/", methods=["POST"])
    def predict_request():
        try:
            pdf = request.data
            if not pdf:
                # An empty body is a client error; without this check it
                # surfaced as a 500 from the PDF converter.
                return jsonify("Empty request body"), 400

            # Rasterize the PDF into page images, then run the model on them.
            pages = pdf2image.convert_from_bytes(pdf)
            predictions = predictor.predict(pages)
            return jsonify(list(predictions))
        except Exception as err:
            # Top-level request boundary: log with traceback and report a
            # generic failure rather than leaking internals to the client.
            logging.warning("Analysis failed")
            logging.exception(err)
            return jsonify("Analysis failed"), 500

    @app.route("/status", methods=["GET"])
    def status():
        return jsonify("OK")

    return app


def run_server(resume, *, host="127.0.0.1", port=5000, debug=False):
    """Load the predictor and serve it with Flask's built-in server.

    Args:
        resume: Optional checkpoint path overriding the configured default.
        host: Interface to bind; loopback only by default.
        port: TCP port to listen on.
        debug: Enable Flask debug mode. Off by default because debug mode
            starts the Werkzeug reloader, which re-executes this script in a
            child process and loads the model a second time. It also exposes
            the interactive debugger, which permits arbitrary code execution.
    """
    set_torch_env()

    predictor = _initialize_predictor(resume)
    logging.info("Predictor ready.")

    app = _create_app(predictor)
    app.run(host=host, port=port, debug=debug)
|
|
|
|
|
|
if __name__ == "__main__":
    # Script entry: parse the command line and hand it straight to main().
    main(parse_args())
|