diff --git a/.coveragerc b/.coveragerc index 77e78ab..a868b80 100644 --- a/.coveragerc +++ b/.coveragerc @@ -13,6 +13,7 @@ omit = */build_env/* */utils/banner.py */utils/logger.py + */src/* source = image_prediction src @@ -48,6 +49,7 @@ omit = */build_env/* */utils/banner.py */utils/logger.py + */src/* ignore_errors = True diff --git a/image_prediction/flask.py b/image_prediction/flask.py index 19a0e8c..8cd5111 100644 --- a/image_prediction/flask.py +++ b/image_prediction/flask.py @@ -2,14 +2,18 @@ import multiprocessing from typing import Callable from flask import Flask, request, jsonify +from waitress import serve from image_prediction.utils import get_logger logger = get_logger() -def make_prediction_server(predict_fn: Callable): +def run_prediction_server(app, host, port): + serve(app, host=host, port=port, _quiet=False) + +def make_prediction_server(predict_fn: Callable): app = Flask(__name__) @app.route("/ready", methods=["GET"]) @@ -49,11 +53,11 @@ def make_prediction_server(predict_fn: Callable): except KeyError: raise - logger.debug("Running classifier on document...") + logger.info("Analysing document...") try: predictions = process() response = jsonify(predictions) - logger.info("Analysis completed.") + logger.debug("Analysis completed.") return response except Exception as err: logger.error("Analysis failed.") diff --git a/image_prediction/utils/logger.py b/image_prediction/utils/logger.py index 4f5186f..c712da4 100644 --- a/image_prediction/utils/logger.py +++ b/image_prediction/utils/logger.py @@ -23,4 +23,3 @@ def make_logger_getter(): get_logger = make_logger_getter() -1 \ No newline at end of file diff --git a/src/serve.py b/src/serve.py index c0de75f..bd4fece 100644 --- a/src/serve.py +++ b/src/serve.py @@ -1,10 +1,9 @@ import logging -from waitress import serve - from image_prediction.config import CONFIG from image_prediction.default_objects import load_pipeline from image_prediction.flask import make_prediction_server +from image_prediction.flask import run_prediction_server from image_prediction.utils import get_logger from image_prediction.utils.banner import show_banner @@ -22,25 +21,17 @@ def main(): return pipeline(pdf) prediction_server = make_prediction_server(predict) - - run_prediction_server(prediction_server, mode=CONFIG.webserver.mode) - - -def run_prediction_server(app, mode="development"): - if mode == "development": - app.run(host=CONFIG.webserver.host, port=CONFIG.webserver.port, debug=True) - elif mode == "production": - serve(app, host=CONFIG.webserver.host, port=CONFIG.webserver.port, _quiet=False) + run_prediction_server(prediction_server, CONFIG.webserver.host, CONFIG.webserver.port, CONFIG.webserver.mode) if __name__ == "__main__": logging.basicConfig(level=CONFIG.service.logging_level) - logging.getLogger("flask").setLevel(logging.ERROR) - logging.getLogger("urllib3").setLevel(logging.ERROR) - logging.getLogger("werkzeug").setLevel(logging.ERROR) - logging.getLogger("waitress").setLevel(logging.ERROR) - logging.getLogger("PIL").setLevel(logging.ERROR) - logging.getLogger("h5py").setLevel(logging.ERROR) + # logging.getLogger("flask").setLevel(logging.ERROR) + # logging.getLogger("urllib3").setLevel(logging.ERROR) + # logging.getLogger("werkzeug").setLevel(logging.ERROR) + # logging.getLogger("waitress").setLevel(logging.ERROR) + # logging.getLogger("PIL").setLevel(logging.ERROR) + # logging.getLogger("h5py").setLevel(logging.ERROR) show_banner() diff --git a/test/integration_tests/server_test.py b/test/integration_tests/server_test.py new file mode 100644 index 0000000..8951940 --- /dev/null +++ b/test/integration_tests/server_test.py @@ -0,0 +1,74 @@ +from multiprocessing import Process +from time import sleep + +import pytest +import requests + +from image_prediction.flask import make_prediction_server, run_prediction_server + + +@pytest.fixture +def host(): + return "127.0.0.1" + + +@pytest.fixture +def port(host): + import socket + sock = socket.socket() + sock.bind((host, 0)) + return sock.getsockname()[1] + + +@pytest.fixture +def url(host, port): + return f"http://{host}:{port}" + + +@pytest.fixture +def predict_fn(): + def predict(_): + return 42 + + return predict + + +@pytest.fixture +def run_server_args(host, port, predict_fn): + prediction_server = make_prediction_server(predict_fn) + return {"app": prediction_server, "host": host, "port": port} + + +@pytest.fixture +def server(run_server_args): + + def get_server_process(): + return Process(target=run_prediction_server, kwargs=run_server_args) + + server = get_server_process() + server.start() + yield + server.terminate() + + +# def test_run_server(): +# prediction_server = make_prediction_server(predict_fn) +# return Process(target=run_prediction_server, kwargs={"app": prediction_server, "host": host, "port": port}) + + +def test_server_predict(server, url): + response = requests.post(url) + response.raise_for_status() + assert response.json() == 42 + + +def test_server_health_check(server, url): + response = requests.get(f"{url}/health") + response.raise_for_status() + assert response.status_code == 200 + + +def test_server_ready_check(server, url): + response = requests.get(f"{url}/ready") + response.raise_for_status() + assert response.status_code == 200 diff --git a/test/unit_tests/__init__.py b/test/unit_tests/__init__.py deleted file mode 100644 index e69de29..0000000