97 lines
3.1 KiB
Python
97 lines
3.1 KiB
Python
import asyncio
|
|
import inspect
|
|
import logging
|
|
import signal
|
|
import threading
|
|
import time
|
|
from typing import Callable
|
|
|
|
import uvicorn
|
|
from dynaconf import Dynaconf
|
|
from fastapi import FastAPI
|
|
from kn_utils.logging import logger
|
|
from kn_utils.retry import retry
|
|
|
|
from pyinfra.config.loader import validate_settings
|
|
from pyinfra.config.validators import webserver_validators
|
|
|
|
|
|
@retry(tries=5, exceptions=Exception, reraise=True)
def create_webserver_thread_from_settings(app: FastAPI, settings: Dynaconf) -> threading.Thread:
    """Validate the webserver settings and build the (not yet started) server thread from them.

    The settings are checked against ``webserver_validators`` first; the port and host are then
    read from ``settings.webserver`` and handed to ``create_webserver_thread``.

    The ``retry`` decorator is configured with ``tries=5``, ``exceptions=Exception`` and
    ``reraise=True``.
    """
    validate_settings(settings, validators=webserver_validators)

    thread_kwargs = {
        "app": app,
        "port": settings.webserver.port,
        "host": settings.webserver.host,
    }
    return create_webserver_thread(**thread_kwargs)
|
|
|
|
|
|
def create_webserver_thread(app: FastAPI, port: int, host: str) -> threading.Thread:
    """Build a daemon thread that serves ``app`` with uvicorn; the thread is returned unstarted.

    Start it with ``thread.start()`` and wait on it with ``thread.join()``. Because the thread is a
    daemon thread, it is terminated when the main thread terminates.
    """
    max_attempts = 5

    # The target keeps the name ``run_server`` because the default thread name embeds it.
    def run_server():
        # NOTE(review): uvicorn may signal a startup failure through SystemExit, which the
        # ``except Exception`` below would not catch - confirm against the uvicorn version in use.
        attempt_index = 0
        while attempt_index < max_attempts:
            try:
                uvicorn.run(app, port=port, host=host, log_level=logging.WARNING)
            except Exception as error:
                if attempt_index >= max_attempts - 1:
                    # Last attempt: report and let the exception propagate.
                    logger.error(f"Failed to start the server after {max_attempts} attempts: {error}")
                    raise
                logger.warning(f"Attempt {attempt_index + 1} failed to start the server: {error}. Retrying...")
                # Exponential backoff: 1s, 2s, 4s, 8s between attempts.
                time.sleep(2**attempt_index)
            else:
                # uvicorn.run returned without raising, so there is nothing to retry.
                return
            attempt_index += 1

    return threading.Thread(target=run_server, daemon=True)
|
|
|
|
|
|
async def run_async_webserver(app: FastAPI, port: int, host: str):
    """Serve ``app`` with uvicorn on the running event loop.

    If the task is cancelled while serving, the server is flagged to exit and shut down, and the
    ``CancelledError`` is not re-raised. Any other exception raised by ``serve()`` is logged and
    likewise not propagated. A shutdown message is logged on every exit path.
    """
    uvicorn_server = uvicorn.Server(
        uvicorn.Config(app, host=host, port=port, log_level=logging.WARNING),
    )

    try:
        await uvicorn_server.serve()
    except asyncio.CancelledError:
        logger.info("Webserver was cancelled.")
        # Ask uvicorn to stop its main loop, then close it down explicitly.
        uvicorn_server.should_exit = True
        await uvicorn_server.shutdown()
    except Exception as error:
        logger.error(f"Error while running the webserver: {error}", exc_info=True)
    finally:
        logger.info("Webserver has been shut down.")
|
|
|
|
|
|
# Type of a health probe: a callable taking no arguments that returns True when the service is
# healthy and False otherwise. add_health_check_endpoint also accepts coroutine functions and
# awaits their result.
HealthFunction = Callable[[], bool]
|
|
|
|
|
|
def _health_response(alive: bool):
    """Translate a health probe result into an HTTP response.

    Returns a 200 response with body ``{"status": "OK"}`` when ``alive`` is truthy, otherwise a
    503 response with body ``{"status": "Service Unavailable"}``.
    """
    # Function-scope import: JSONResponse is only needed here.
    from fastapi.responses import JSONResponse

    # FastAPI does not interpret a returned ``(body, status)`` tuple the way Flask does: the tuple
    # would be serialized as a JSON array and sent with status 200. An explicit response object is
    # required for the 503 to actually reach the client.
    if alive:
        return JSONResponse(content={"status": "OK"}, status_code=200)
    return JSONResponse(content={"status": "Service Unavailable"}, status_code=503)


def add_health_check_endpoint(app: FastAPI, health_function: HealthFunction) -> FastAPI:
    """Register ``GET /health`` and ``GET /ready`` endpoints on ``app``.

    The health function is called each time one of the endpoints is hit. It should return True if
    the service is healthy and False otherwise; the endpoint answers with HTTP 200 and
    ``{"status": "OK"}`` or with HTTP 503 and ``{"status": "Service Unavailable"}`` accordingly.

    Args:
        app: The application to register the endpoints on.
        health_function: A plain function or a coroutine function taking no arguments. Coroutine
            functions are awaited inside an async endpoint.

    Returns:
        The same ``app`` instance, so the call can be chained.
    """
    if inspect.iscoroutinefunction(health_function):

        @app.get("/health")
        @app.get("/ready")
        async def async_check_health():
            return _health_response(await health_function())

    else:

        @app.get("/health")
        @app.get("/ready")
        def check_health():
            return _health_response(health_function())

    return app
|