diff --git a/image_prediction/flask.py b/image_prediction/flask.py index 0e9c448..7f76775 100644 --- a/image_prediction/flask.py +++ b/image_prediction/flask.py @@ -3,17 +3,12 @@ import traceback from typing import Callable from flask import Flask, request, jsonify -from waitress import serve from image_prediction.utils import get_logger logger = get_logger() -def run_prediction_server(app, host, port): - serve(app, host=host, port=port, _quiet=False) - - def run_in_process(func): p = multiprocessing.Process(target=func) p.start() diff --git a/src/serve.py b/src/serve.py index 84808f1..59c2d15 100644 --- a/src/serve.py +++ b/src/serve.py @@ -1,9 +1,10 @@ import logging +from waitress import serve + from image_prediction.config import CONFIG from image_prediction.default_objects import load_pipeline from image_prediction.flask import make_prediction_server -from image_prediction.flask import run_prediction_server from image_prediction.utils import get_logger from image_prediction.utils.banner import show_banner @@ -21,7 +22,7 @@ def main(): return pipeline(pdf) prediction_server = make_prediction_server(predict) - run_prediction_server(prediction_server, CONFIG.webserver.host, CONFIG.webserver.port) + serve(prediction_server, host=CONFIG.webserver.host, port=CONFIG.webserver.port, _quiet=False) if __name__ == "__main__": diff --git a/test/integration_tests/server_test.py b/test/integration_tests/server_test.py index f7ef3ae..e215f1f 100644 --- a/test/integration_tests/server_test.py +++ b/test/integration_tests/server_test.py @@ -4,8 +4,9 @@ from multiprocessing import Process import pytest import requests from funcy import retry +from waitress import serve -from image_prediction.flask import make_prediction_server, run_prediction_server +from image_prediction.flask import make_prediction_server @pytest.fixture @@ -50,7 +51,7 @@ def server_ready(url): @pytest.fixture(autouse=True, scope="function") def server_process(server, host_and_port, url): def get_server_process(): - return Process(target=run_prediction_server, kwargs={"app": server, **host_and_port}) + return Process(target=serve, kwargs={"app": server, **host_and_port}) server = get_server_process() server.start()