diff --git a/incl/detr b/incl/detr index c17cddd..0d2808c 160000 --- a/incl/detr +++ b/incl/detr @@ -1 +1 @@ -Subproject commit c17cddd980ae3003a2633a65744d2265228e4c71 +Subproject commit 0d2808c70737fb9fed665334c4cdec7fd39b2e4b diff --git a/requirements.txt b/requirements.txt index 250550b..e28f04f 100644 --- a/requirements.txt +++ b/requirements.txt @@ -13,3 +13,4 @@ iteration-utilities==0.11.0 dvc==2.9.3 dvc[ssh] frozendict==2.3.0 +waitress==2.0.0 diff --git a/setup/docker.sh b/setup/docker.sh index 22a492d..93af6fd 100755 --- a/setup/docker.sh +++ b/setup/docker.sh @@ -12,4 +12,4 @@ dvc pull git submodule update --init --recursive docker build -f Dockerfile-base -t detr-server-base . -docker build -f Dockerfile -t detr-server . --build-arg +docker build -f Dockerfile -t detr-server . diff --git a/src/run_service.py b/src/run_service.py index 25f3720..17f7492 100644 --- a/src/run_service.py +++ b/src/run_service.py @@ -4,14 +4,21 @@ import os from fb_detr.locations import DATA_DIR from fb_detr.locations import TORCH_HOME from fb_detr.predictor import Predictor +from waitress import serve from flask import Flask, request, jsonify from pdf2image import pdf2image from fb_detr.utils.config import read_config +def suppress_userwarnings(): + import warnings + warnings.filterwarnings("ignore") + + def parse_args(): parser = argparse.ArgumentParser() parser.add_argument("--resume") + parser.add_argument("--warnings", action="store_true", default=False) args = parser.parse_args() return args @@ -32,10 +39,19 @@ def set_torch_env(): def main(args): + + if not args.warnings: + suppress_userwarnings() + + run_server(args.resume) + + +def run_server(resume): + set_torch_env() def initialize_predictor(): - checkpoint = get_checkpoint() if not args.resume else args.resume + checkpoint = get_checkpoint() if not resume else resume predictor = Predictor(checkpoint, classes=load_classes(), rejection_class=read_config("rejection_class")) return predictor @@ -58,7 +74,7 @@ def main(args): predictor = initialize_predictor() - app.run(host="127.0.0.1", port=5000) + serve(app, host="127.0.0.1", port=5000) if __name__ == "__main__":