diff --git a/pyinfra/webserver/utils.py b/pyinfra/webserver/utils.py index 710c26a..9a2a438 100644 --- a/pyinfra/webserver/utils.py +++ b/pyinfra/webserver/utils.py @@ -1,3 +1,4 @@ +import inspect import logging import threading from typing import Callable @@ -31,13 +32,23 @@ def add_health_check_endpoint(app: FastAPI, health_function: HealthFunction) -> """Add a health check endpoint to the app. The health function should return True if the service is healthy, and False otherwise. The health function is called when the endpoint is hit. """ + if inspect.iscoroutinefunction(health_function): - @app.get("/health") - @app.get("/ready") - def check_health(): - if health_function(): - return {"status": "OK"}, 200 - else: + @app.get("/health") + @app.get("/ready") + async def async_check_health(): + alive = await health_function() + if alive: + return {"status": "OK"}, 200 + return {"status": "Service Unavailable"}, 503 + + else: + + @app.get("/health") + @app.get("/ready") + def check_health(): + if health_function(): + return {"status": "OK"}, 200 return {"status": "Service Unavailable"}, 503 return app