web server refactoring + tests
This commit is contained in:
parent
dd007891c7
commit
c125e1ff6c
@ -13,6 +13,7 @@ omit =
|
|||||||
*/build_env/*
|
*/build_env/*
|
||||||
*/utils/banner.py
|
*/utils/banner.py
|
||||||
*/utils/logger.py
|
*/utils/logger.py
|
||||||
|
*/src/*
|
||||||
source =
|
source =
|
||||||
image_prediction
|
image_prediction
|
||||||
src
|
src
|
||||||
@ -48,6 +49,7 @@ omit =
|
|||||||
*/build_env/*
|
*/build_env/*
|
||||||
*/utils/banner.py
|
*/utils/banner.py
|
||||||
*/utils/logger.py
|
*/utils/logger.py
|
||||||
|
*/src/*
|
||||||
|
|
||||||
ignore_errors = True
|
ignore_errors = True
|
||||||
|
|
||||||
|
|||||||
@ -2,14 +2,18 @@ import multiprocessing
|
|||||||
from typing import Callable
|
from typing import Callable
|
||||||
|
|
||||||
from flask import Flask, request, jsonify
|
from flask import Flask, request, jsonify
|
||||||
|
from waitress import serve
|
||||||
|
|
||||||
from image_prediction.utils import get_logger
|
from image_prediction.utils import get_logger
|
||||||
|
|
||||||
logger = get_logger()
|
logger = get_logger()
|
||||||
|
|
||||||
|
|
||||||
def make_prediction_server(predict_fn: Callable):
|
def run_prediction_server(app, host, port):
|
||||||
|
serve(app, host=host, port=port, _quiet=False)
|
||||||
|
|
||||||
|
|
||||||
|
def make_prediction_server(predict_fn: Callable):
|
||||||
app = Flask(__name__)
|
app = Flask(__name__)
|
||||||
|
|
||||||
@app.route("/ready", methods=["GET"])
|
@app.route("/ready", methods=["GET"])
|
||||||
@ -49,11 +53,11 @@ def make_prediction_server(predict_fn: Callable):
|
|||||||
except KeyError:
|
except KeyError:
|
||||||
raise
|
raise
|
||||||
|
|
||||||
logger.debug("Running classifier on document...")
|
logger.info("Analysing document...")
|
||||||
try:
|
try:
|
||||||
predictions = process()
|
predictions = process()
|
||||||
response = jsonify(predictions)
|
response = jsonify(predictions)
|
||||||
logger.info("Analysis completed.")
|
logger.debug("Analysis completed.")
|
||||||
return response
|
return response
|
||||||
except Exception as err:
|
except Exception as err:
|
||||||
logger.error("Analysis failed.")
|
logger.error("Analysis failed.")
|
||||||
|
|||||||
@ -23,4 +23,3 @@ def make_logger_getter():
|
|||||||
|
|
||||||
|
|
||||||
get_logger = make_logger_getter()
|
get_logger = make_logger_getter()
|
||||||
1
|
|
||||||
25
src/serve.py
25
src/serve.py
@ -1,10 +1,9 @@
|
|||||||
import logging
|
import logging
|
||||||
|
|
||||||
from waitress import serve
|
|
||||||
|
|
||||||
from image_prediction.config import CONFIG
|
from image_prediction.config import CONFIG
|
||||||
from image_prediction.default_objects import load_pipeline
|
from image_prediction.default_objects import load_pipeline
|
||||||
from image_prediction.flask import make_prediction_server
|
from image_prediction.flask import make_prediction_server
|
||||||
|
from image_prediction.flask import run_prediction_server
|
||||||
from image_prediction.utils import get_logger
|
from image_prediction.utils import get_logger
|
||||||
from image_prediction.utils.banner import show_banner
|
from image_prediction.utils.banner import show_banner
|
||||||
|
|
||||||
@ -22,25 +21,17 @@ def main():
|
|||||||
return pipeline(pdf)
|
return pipeline(pdf)
|
||||||
|
|
||||||
prediction_server = make_prediction_server(predict)
|
prediction_server = make_prediction_server(predict)
|
||||||
|
run_prediction_server(prediction_server, CONFIG.webserver.host, CONFIG.webserver.port, CONFIG.webserver.mode)
|
||||||
run_prediction_server(prediction_server, mode=CONFIG.webserver.mode)
|
|
||||||
|
|
||||||
|
|
||||||
def run_prediction_server(app, mode="development"):
|
|
||||||
if mode == "development":
|
|
||||||
app.run(host=CONFIG.webserver.host, port=CONFIG.webserver.port, debug=True)
|
|
||||||
elif mode == "production":
|
|
||||||
serve(app, host=CONFIG.webserver.host, port=CONFIG.webserver.port, _quiet=False)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
logging.basicConfig(level=CONFIG.service.logging_level)
|
logging.basicConfig(level=CONFIG.service.logging_level)
|
||||||
logging.getLogger("flask").setLevel(logging.ERROR)
|
# logging.getLogger("flask").setLevel(logging.ERROR)
|
||||||
logging.getLogger("urllib3").setLevel(logging.ERROR)
|
# logging.getLogger("urllib3").setLevel(logging.ERROR)
|
||||||
logging.getLogger("werkzeug").setLevel(logging.ERROR)
|
# logging.getLogger("werkzeug").setLevel(logging.ERROR)
|
||||||
logging.getLogger("waitress").setLevel(logging.ERROR)
|
# logging.getLogger("waitress").setLevel(logging.ERROR)
|
||||||
logging.getLogger("PIL").setLevel(logging.ERROR)
|
# logging.getLogger("PIL").setLevel(logging.ERROR)
|
||||||
logging.getLogger("h5py").setLevel(logging.ERROR)
|
# logging.getLogger("h5py").setLevel(logging.ERROR)
|
||||||
|
|
||||||
show_banner()
|
show_banner()
|
||||||
|
|
||||||
|
|||||||
74
test/integration_tests/server_test.py
Normal file
74
test/integration_tests/server_test.py
Normal file
@ -0,0 +1,74 @@
|
|||||||
|
from multiprocessing import Process
|
||||||
|
from time import sleep
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import requests
|
||||||
|
|
||||||
|
from image_prediction.flask import make_prediction_server, run_prediction_server
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def host():
|
||||||
|
return "127.0.0.1"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def port(host):
|
||||||
|
import socket
|
||||||
|
sock = socket.socket()
|
||||||
|
sock.bind((host, 0))
|
||||||
|
return sock.getsockname()[1]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def url(host, port):
|
||||||
|
return f"http://{host}:{port}"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def predict_fn():
|
||||||
|
def predict(_):
|
||||||
|
return 42
|
||||||
|
|
||||||
|
return predict
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def run_server_args(host, port, predict_fn):
|
||||||
|
prediction_server = make_prediction_server(predict_fn)
|
||||||
|
return {"app": prediction_server, "host": host, "port": port}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def server(run_server_args):
|
||||||
|
|
||||||
|
def get_server_process():
|
||||||
|
return Process(target=run_prediction_server, kwargs=run_server_args)
|
||||||
|
|
||||||
|
server = get_server_process()
|
||||||
|
server.start()
|
||||||
|
yield
|
||||||
|
server.terminate()
|
||||||
|
|
||||||
|
|
||||||
|
# def test_run_server():
|
||||||
|
# prediction_server = make_prediction_server(predict_fn)
|
||||||
|
# return Process(target=run_prediction_server, kwargs={"app": prediction_server, "host": host, "port": port})
|
||||||
|
|
||||||
|
|
||||||
|
def test_server_predict(server, url):
|
||||||
|
response = requests.post(url)
|
||||||
|
response.raise_for_status()
|
||||||
|
assert response.json() == 42
|
||||||
|
|
||||||
|
|
||||||
|
def test_server_health_check(server, url):
|
||||||
|
response = requests.get(f"{url}/health")
|
||||||
|
response.raise_for_status()
|
||||||
|
assert response.status_code == 200
|
||||||
|
|
||||||
|
|
||||||
|
def test_server_ready_check(server, url):
|
||||||
|
response = requests.get(f"{url}/ready")
|
||||||
|
response.raise_for_status()
|
||||||
|
assert response.status_code == 200
|
||||||
Loading…
x
Reference in New Issue
Block a user