2024-10-23 16:06:06 +02:00

97 lines
3.1 KiB
Python

import asyncio
import inspect
import logging
import signal
import threading
import time
from typing import Callable
import uvicorn
from dynaconf import Dynaconf
from fastapi import FastAPI
from kn_utils.logging import logger
from kn_utils.retry import retry
from pyinfra.config.loader import validate_settings
from pyinfra.config.validators import webserver_validators
@retry(
    tries=5,
    exceptions=Exception,
    reraise=True,
)
def create_webserver_thread_from_settings(app: FastAPI, settings: Dynaconf) -> threading.Thread:
    """Build the webserver thread for *app* from the ``webserver`` section of *settings*.

    The settings are first checked against ``webserver_validators``; port and host are then read from
    ``settings.webserver`` and passed on to ``create_webserver_thread``. The returned thread is not started.
    The function is wrapped by ``retry`` configured with ``tries=5``, ``exceptions=Exception`` and
    ``reraise=True``.
    """
    validate_settings(settings, validators=webserver_validators)
    webserver_settings = settings.webserver
    thread_kwargs = {"app": app, "port": webserver_settings.port, "host": webserver_settings.host}
    return create_webserver_thread(**thread_kwargs)
def create_webserver_thread(app: FastAPI, port: int, host: str) -> threading.Thread:
    """Creates a thread that runs a FastAPI webserver. Start with thread.start(), and join with thread.join().
    Note that the thread is a daemon thread, so it will be terminated when the main thread is terminated.

    Starting the server is attempted up to five times, sleeping 1s, 2s, 4s and 8s between attempts
    (exponential backoff). After the last failed attempt the error is logged and re-raised inside the thread.

    Args:
        app: The application handed to ``uvicorn.run``.
        port: Port the server binds to.
        host: Host/interface the server binds to.

    Returns:
        An unstarted daemon thread whose target runs the server.
    """

    def run_server():
        retries = 5
        for attempt in range(retries):
            try:
                uvicorn.run(app, port=port, host=host, log_level=logging.WARNING)
                break
            # uvicorn reports startup failures (e.g. the port is already in use) by calling sys.exit(), which
            # raises SystemExit. SystemExit derives from BaseException, not Exception, so it must be caught
            # explicitly; otherwise the thread dies silently on the first failure and nothing is retried.
            except (Exception, SystemExit) as e:
                if attempt < retries - 1:  # if it's not the last attempt
                    logger.warning(f"Attempt {attempt + 1} failed to start the server: {e}. Retrying...")
                    time.sleep(2**attempt)  # exponential backoff
                else:
                    logger.error(f"Failed to start the server after {retries} attempts: {e}")
                    raise

    return threading.Thread(target=run_server, daemon=True)
async def run_async_webserver(app: FastAPI, port: int, host: str):
    """Serve *app* with uvicorn inside the current event loop.

    If the surrounding task is cancelled, the server is flagged to exit and shut down, and the
    ``CancelledError`` is not propagated. Any other exception is logged and not re-raised. A final log line
    is written in every case.
    """
    uvicorn_server = uvicorn.Server(
        uvicorn.Config(app, host=host, port=port, log_level=logging.WARNING),
    )

    try:
        await uvicorn_server.serve()
    except asyncio.CancelledError:
        logger.info("Webserver was cancelled.")
        # Ask the server to stop and let it close its resources before returning.
        uvicorn_server.should_exit = True
        await uvicorn_server.shutdown()
    except Exception as error:
        logger.error(f"Error while running the webserver: {error}", exc_info=True)
    finally:
        logger.info("Webserver has been shut down.")
# Type of a health probe: a zero-argument callable that returns True when the service is healthy.
# add_health_check_endpoint also accepts coroutine functions and awaits their result.
HealthFunction = Callable[[], bool]
def add_health_check_endpoint(app: FastAPI, health_function: HealthFunction) -> FastAPI:
    """Add a health check endpoint to the app. The health function should return True if the service is healthy,
    and False otherwise. The health function is called when the endpoint is hit.

    The same handler is registered for ``GET /health`` and ``GET /ready``. It answers with HTTP 200 and
    ``{"status": "OK"}`` when the health function returns a truthy value, and with HTTP 503 and
    ``{"status": "Service Unavailable"}`` otherwise.

    Args:
        app: The application the routes are registered on.
        health_function: Plain or coroutine function without arguments; coroutine functions are awaited.

    Returns:
        The same ``app`` instance, to allow chaining.
    """

    def _to_response(alive: bool):
        # FastAPI does not interpret a ``(body, status)`` tuple the way Flask does: the tuple would be
        # serialised as a JSON array and sent with HTTP 200. An explicit response object is needed so that
        # an unhealthy service really answers with 503.
        from fastapi.responses import JSONResponse  # local import; cached by Python after the first call

        if alive:
            return JSONResponse(content={"status": "OK"}, status_code=200)
        return JSONResponse(content={"status": "Service Unavailable"}, status_code=503)

    if inspect.iscoroutinefunction(health_function):

        @app.get("/health")
        @app.get("/ready")
        async def async_check_health():
            return _to_response(await health_function())

    else:

        @app.get("/health")
        @app.get("/ready")
        def check_health():
            return _to_response(health_function())

    return app