fix: use signals for graceful shutdown

This commit is contained in:
Jonathan Kössler 2024-11-13 13:54:41 +01:00
parent e51e5c33eb
commit 9d4ec84b49
3 changed files with 64 additions and 10 deletions

View File

@ -1,4 +1,5 @@
import asyncio
import signal
import sys
import aiohttp
@ -22,12 +23,37 @@ from pyinfra.webserver.utils import (
run_async_webserver,
)
shutdown_flag = False


async def graceful_shutdown(manager, queue_task, webserver_task):
    """Cancel the running tasks and shut the queue manager down.

    Intended to be scheduled from an event-loop signal handler. Sets the
    module-level ``shutdown_flag`` so other shutdown paths can tell that a
    signal-driven shutdown has already started.

    Args:
        manager: Object exposing an awaitable ``shutdown()`` method.
        queue_task: Task running the queues, or ``None`` if it was never created.
        webserver_task: Task running the webserver, or ``None`` if it was never
            created.
    """
    global shutdown_flag
    shutdown_flag = True

    # Kept signal-agnostic: this coroutine is not only triggered by SIGTERM.
    logger.info("Shutdown signal received, shutting down gracefully...")

    if queue_task is not None and not queue_task.done():
        queue_task.cancel()

    await manager.shutdown()

    if webserver_task is not None and not webserver_task.done():
        webserver_task.cancel()

    # A signal can arrive before either task exists. asyncio.gather() raises
    # TypeError when handed None, so only wait on tasks that were created.
    pending = [task for task in (queue_task, webserver_task) if task is not None]
    if pending:
        await asyncio.gather(*pending, return_exceptions=True)

    logger.info("Shutdown complete.")
async def run_async_queues(manager: AsyncQueueManager, app, port, host):
"""Run the async webserver and the async queue manager concurrently."""
queue_task = None
webserver_task = None
tenant_api_available = True
# add signal handler for SIGTERM and SIGINT
loop = asyncio.get_running_loop()
loop.add_signal_handler(
signal.SIGTERM, lambda: asyncio.create_task(graceful_shutdown(manager, queue_task, webserver_task))
)
loop.add_signal_handler(
signal.SIGINT, lambda: asyncio.create_task(graceful_shutdown(manager, queue_task, webserver_task))
)
try:
active_tenants = await manager.fetch_active_tenants()
@ -45,17 +71,21 @@ async def run_async_queues(manager: AsyncQueueManager, app, port, host):
logger.error(f"An error occurred while running async queues: {e}", exc_info=True)
sys.exit(1)
finally:
logger.info("Signal received, shutting down...")
if not tenant_api_available:
sys.exit(0)
if queue_task and not queue_task.done():
queue_task.cancel()
if webserver_task and not webserver_task.done():
webserver_task.cancel()
if shutdown_flag:
logger.info("Graceful shutdown already in progress.")
else:
logger.warning("Initiating shutdown due to error or manual interruption.")
if not tenant_api_available:
sys.exit(0)
if queue_task and not queue_task.done():
queue_task.cancel()
await manager.shutdown()
await manager.shutdown()
if webserver_task and not webserver_task.done():
webserver_task.cancel()
await asyncio.gather(queue_task, webserver_task, return_exceptions=True)
await asyncio.gather(queue_task, webserver_task, return_exceptions=True)
logger.info("Shutdown complete.")
def start_standard_queue_consumer(

View File

@ -16,6 +16,13 @@ from pyinfra.config.loader import validate_settings
from pyinfra.config.validators import webserver_validators
class PyInfraUvicornServer(uvicorn.Server):
    """``uvicorn.Server`` subclass whose signal-handler installation is a no-op.

    This is a workaround to enable custom signal handlers, see
    https://github.com/encode/uvicorn/issues/1579
    """

    def install_signal_handlers(self):
        """Do nothing, so the server does not install signal handlers itself."""
        # NOTE(review): newer uvicorn releases appear to manage signals through
        # a different hook (`capture_signals`) - confirm this override still
        # takes effect with the pinned uvicorn version.
        return None
@retry(
tries=5,
exceptions=Exception,
@ -53,7 +60,7 @@ def create_webserver_thread(app: FastAPI, port: int, host: str) -> threading.Thr
async def run_async_webserver(app: FastAPI, port: int, host: str):
"""Run the FastAPI web server async."""
config = uvicorn.Config(app, host=host, port=port, log_level=logging.WARNING)
server = uvicorn.Server(config)
server = PyInfraUvicornServer(config)
try:
await server.serve()

17
scripts/send_sigterm.py Normal file
View File

@ -0,0 +1,17 @@
import os
import signal
import time
# BE CAREFUL WITH THIS SCRIPT - THIS SIMULATES A SIGTERM FROM KUBERNETES


def parse_pid(raw: str) -> int:
    """Convert user input into a PID that is safe to signal.

    Args:
        raw: Text entered by the user.

    Returns:
        The parsed, strictly positive process id.

    Raises:
        ValueError: If ``raw`` is not an integer or is not greater than zero.
    """
    pid = int(raw.strip())
    # os.kill() treats pid 0 and negative pids as "process group" / "every
    # process" targets, which must never happen by accident here.
    if pid <= 0:
        raise ValueError(f"PID must be a positive integer, got {pid}")
    return pid


def main() -> None:
    """Ask for a PID and send SIGTERM to that single process."""
    try:
        target_pid = parse_pid(input("Enter the PID of the target script: "))
    except ValueError as err:
        print(f"Invalid PID: {err}")
        raise SystemExit(1) from err

    print(f"Sending SIGTERM to PID {target_pid}...")
    time.sleep(1)

    try:
        os.kill(target_pid, signal.SIGTERM)
    except ProcessLookupError:
        print("Process not found.")
    except PermissionError:
        print("Permission denied. Are you trying to signal a process you don't own?")
    else:
        print("SIGTERM sent.")


if __name__ == "__main__":
    main()