diff --git a/pyinfra/monitor/__init__.py b/pyinfra/monitor/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/pyinfra/monitor/prometheus.py b/pyinfra/webserver/prometheus.py similarity index 100% rename from pyinfra/monitor/prometheus.py rename to pyinfra/webserver/prometheus.py diff --git a/pyinfra/webserver/utils.py b/pyinfra/webserver/utils.py index 70a9c2b..fc7534b 100644 --- a/pyinfra/webserver/utils.py +++ b/pyinfra/webserver/utils.py @@ -1,5 +1,6 @@ import logging import threading +from typing import Callable import uvicorn from dynaconf import Dynaconf @@ -21,3 +22,22 @@ def create_webserver_thread(app: FastAPI, port: int, host: str) -> threading.Thr thread = threading.Thread(target=lambda: uvicorn.run(app, port=port, host=host, log_level=logging.WARNING)) thread.daemon = True return thread + + +HealthFunction = Callable[[], bool] + + +def add_health_check_endpoint(app: FastAPI, health_function: HealthFunction) -> FastAPI: + """Add a health check endpoint to the app. The health function should return True if the service is healthy, + and False otherwise. The health function is called when the endpoint is hit. + """ + + @app.get("/health") + @app.get("/ready") + def check_health(): + if health_function(): + return {"status": "OK"}, 200 + else: + return {"status": "Service Unavailable"}, 503 + + return app diff --git a/scripts/start_pyinfra.py b/scripts/start_pyinfra.py index ea959b5..86fca6c 100644 --- a/scripts/start_pyinfra.py +++ b/scripts/start_pyinfra.py @@ -4,10 +4,10 @@ import time from fastapi import FastAPI from pyinfra.config.loader import load_settings -from pyinfra.monitor.prometheus import make_prometheus_processing_time_decorator_from_settings, add_prometheus_endpoint +from pyinfra.webserver.prometheus import make_prometheus_processing_time_decorator_from_settings, add_prometheus_endpoint from pyinfra.queue.callback import make_queue_message_callback from pyinfra.queue.manager import QueueManager -from pyinfra.webserver.utils import create_webserver_thread_from_settings +from pyinfra.webserver.utils import create_webserver_thread_from_settings, add_health_check_endpoint logging.basicConfig() logger = logging.getLogger() @@ -28,13 +28,7 @@ def main(): queue_manager = QueueManager(settings) - @app.get("/ready") - @app.get("/health") - def check_health(): - if queue_manager.is_ready(): - return {"status": "OK"}, 200 - else: - return {"status": "Service Unavailable"}, 503 + app = add_health_check_endpoint(app, queue_manager.is_ready) webserver_thread = create_webserver_thread_from_settings(app, settings) webserver_thread.start() diff --git a/tests/unit_test/prometheus_monitoring_test.py b/tests/unit_test/prometheus_monitoring_test.py index 452954c..3ced056 100644 --- a/tests/unit_test/prometheus_monitoring_test.py +++ b/tests/unit_test/prometheus_monitoring_test.py @@ -5,7 +5,7 @@ import pytest import requests from fastapi import FastAPI -from pyinfra.monitor.prometheus import add_prometheus_endpoint, make_prometheus_processing_time_decorator_from_settings +from pyinfra.webserver.prometheus import add_prometheus_endpoint, make_prometheus_processing_time_decorator_from_settings from pyinfra.webserver.utils import create_webserver_thread_from_settings diff --git a/tests/unit_test/queue_test.py b/tests/unit_test/queue_test.py index 449cf30..720c6c3 100644 --- a/tests/unit_test/queue_test.py +++ b/tests/unit_test/queue_test.py @@ -15,7 +15,7 @@ logger.add(sink=stdout, level="DEBUG") def make_callback(process_time): def callback(x): sleep(process_time) - return json.dumps({"status": "success"}) + return {"status": "success"} return callback @@ -56,7 +56,7 @@ class TestQueueManager: for _ in range(2): response = queue_manager.get_message_from_output_queue() assert response is not None - assert response[2] == b'{"status": "success"}' + assert json.loads(response[2].decode()) == {"status": "success"} def test_all_headers_beginning_with_x_are_forwarded(self, queue_manager, input_message, stop_message): queue_manager.purge_queues() @@ -78,7 +78,7 @@ class TestQueueManager: response = queue_manager.get_message_from_output_queue() - assert response[2] == b'{"status": "success"}' + assert json.loads(response[2].decode()) == {"status": "success"} assert response[1].headers["X-TENANT-ID"] == "redaction" assert response[1].headers["X-OTHER-HEADER"] == "other-header-value" @@ -97,4 +97,4 @@ class TestQueueManager: response = queue_manager.get_message_from_output_queue() - assert response[2] == b'{"status": "success"}' + assert json.loads(response[2].decode()) == {"status": "success"}