web server refactoring + tests

This commit is contained in:
Matthias Bisping 2022-03-31 23:43:14 +02:00
parent dd007891c7
commit c125e1ff6c
6 changed files with 91 additions and 21 deletions

View File

@ -13,6 +13,7 @@ omit =
*/build_env/* */build_env/*
*/utils/banner.py */utils/banner.py
*/utils/logger.py */utils/logger.py
*/src/*
source = source =
image_prediction image_prediction
src src
@ -48,6 +49,7 @@ omit =
*/build_env/* */build_env/*
*/utils/banner.py */utils/banner.py
*/utils/logger.py */utils/logger.py
*/src/*
ignore_errors = True ignore_errors = True

View File

@ -2,14 +2,18 @@ import multiprocessing
from typing import Callable from typing import Callable
from flask import Flask, request, jsonify from flask import Flask, request, jsonify
from waitress import serve
from image_prediction.utils import get_logger from image_prediction.utils import get_logger
logger = get_logger() logger = get_logger()
def make_prediction_server(predict_fn: Callable): def run_prediction_server(app, host, port):
serve(app, host=host, port=port, _quiet=False)
def make_prediction_server(predict_fn: Callable):
app = Flask(__name__) app = Flask(__name__)
@app.route("/ready", methods=["GET"]) @app.route("/ready", methods=["GET"])
@ -49,11 +53,11 @@ def make_prediction_server(predict_fn: Callable):
except KeyError: except KeyError:
raise raise
logger.debug("Running classifier on document...") logger.info("Analysing document...")
try: try:
predictions = process() predictions = process()
response = jsonify(predictions) response = jsonify(predictions)
logger.info("Analysis completed.") logger.debug("Analysis completed.")
return response return response
except Exception as err: except Exception as err:
logger.error("Analysis failed.") logger.error("Analysis failed.")

View File

@ -23,4 +23,3 @@ def make_logger_getter():
get_logger = make_logger_getter() get_logger = make_logger_getter()
1

View File

@ -1,10 +1,9 @@
import logging import logging
from waitress import serve
from image_prediction.config import CONFIG from image_prediction.config import CONFIG
from image_prediction.default_objects import load_pipeline from image_prediction.default_objects import load_pipeline
from image_prediction.flask import make_prediction_server from image_prediction.flask import make_prediction_server
from image_prediction.flask import run_prediction_server
from image_prediction.utils import get_logger from image_prediction.utils import get_logger
from image_prediction.utils.banner import show_banner from image_prediction.utils.banner import show_banner
@ -22,25 +21,17 @@ def main():
return pipeline(pdf) return pipeline(pdf)
prediction_server = make_prediction_server(predict) prediction_server = make_prediction_server(predict)
run_prediction_server(prediction_server, CONFIG.webserver.host, CONFIG.webserver.port, CONFIG.webserver.mode)
run_prediction_server(prediction_server, mode=CONFIG.webserver.mode)
def run_prediction_server(app, mode="development"):
if mode == "development":
app.run(host=CONFIG.webserver.host, port=CONFIG.webserver.port, debug=True)
elif mode == "production":
serve(app, host=CONFIG.webserver.host, port=CONFIG.webserver.port, _quiet=False)
if __name__ == "__main__": if __name__ == "__main__":
logging.basicConfig(level=CONFIG.service.logging_level) logging.basicConfig(level=CONFIG.service.logging_level)
logging.getLogger("flask").setLevel(logging.ERROR) # logging.getLogger("flask").setLevel(logging.ERROR)
logging.getLogger("urllib3").setLevel(logging.ERROR) # logging.getLogger("urllib3").setLevel(logging.ERROR)
logging.getLogger("werkzeug").setLevel(logging.ERROR) # logging.getLogger("werkzeug").setLevel(logging.ERROR)
logging.getLogger("waitress").setLevel(logging.ERROR) # logging.getLogger("waitress").setLevel(logging.ERROR)
logging.getLogger("PIL").setLevel(logging.ERROR) # logging.getLogger("PIL").setLevel(logging.ERROR)
logging.getLogger("h5py").setLevel(logging.ERROR) # logging.getLogger("h5py").setLevel(logging.ERROR)
show_banner() show_banner()

View File

@ -0,0 +1,74 @@
from multiprocessing import Process
from time import sleep
import pytest
import requests
from image_prediction.flask import make_prediction_server, run_prediction_server
@pytest.fixture
def host():
return "127.0.0.1"
@pytest.fixture
def port(host):
import socket
sock = socket.socket()
sock.bind((host, 0))
return sock.getsockname()[1]
@pytest.fixture
def url(host, port):
return f"http://{host}:{port}"
@pytest.fixture
def predict_fn():
def predict(_):
return 42
return predict
@pytest.fixture
def run_server_args(host, port, predict_fn):
prediction_server = make_prediction_server(predict_fn)
return {"app": prediction_server, "host": host, "port": port}
@pytest.fixture
def server(run_server_args):
def get_server_process():
return Process(target=run_prediction_server, kwargs=run_server_args)
server = get_server_process()
server.start()
yield
server.terminate()
# def test_run_server():
# prediction_server = make_prediction_server(predict_fn)
# return Process(target=run_prediction_server, kwargs={"app": prediction_server, "host": host, "port": port})
def test_server_predict(server, url):
response = requests.post(url)
response.raise_for_status()
assert response.json() == 42
def test_server_health_check(server, url):
response = requests.get(f"{url}/health")
response.raise_for_status()
assert response.status_code == 200
def test_server_ready_check(server, url):
response = requests.get(f"{url}/ready")
response.raise_for_status()
assert response.status_code == 200