From 9d4ec84b491d55ec1473b152da7fdc4ff6f3593d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jonathan=20K=C3=B6ssler?= Date: Wed, 13 Nov 2024 13:54:41 +0100 Subject: [PATCH] fix: use signals for graceful shutdown --- pyinfra/examples.py | 48 +++++++++++++++++++++++++++++++------- pyinfra/webserver/utils.py | 9 ++++++- scripts/send_sigterm.py | 17 ++++++++++++++ 3 files changed, 64 insertions(+), 10 deletions(-) create mode 100644 scripts/send_sigterm.py diff --git a/pyinfra/examples.py b/pyinfra/examples.py index adbd879..036cbe5 100644 --- a/pyinfra/examples.py +++ b/pyinfra/examples.py @@ -1,4 +1,5 @@ import asyncio +import signal import sys import aiohttp @@ -22,12 +23,37 @@ from pyinfra.webserver.utils import ( run_async_webserver, ) +shutdown_flag = False + + +async def graceful_shutdown(manager, queue_task, webserver_task): + global shutdown_flag + shutdown_flag = True + logger.info("SIGTERM received, shutting down gracefully...") + if queue_task and not queue_task.done(): + queue_task.cancel() + await manager.shutdown() + if webserver_task and not webserver_task.done(): + webserver_task.cancel() + await asyncio.gather(queue_task, webserver_task, return_exceptions=True) + logger.info("Shutdown complete.") + async def run_async_queues(manager: AsyncQueueManager, app, port, host): """Run the async webserver and the async queue manager concurrently.""" queue_task = None webserver_task = None tenant_api_available = True + + # add signal handler for SIGTERM and SIGINT + loop = asyncio.get_running_loop() + loop.add_signal_handler( + signal.SIGTERM, lambda: asyncio.create_task(graceful_shutdown(manager, queue_task, webserver_task)) + ) + loop.add_signal_handler( + signal.SIGINT, lambda: asyncio.create_task(graceful_shutdown(manager, queue_task, webserver_task)) + ) + try: active_tenants = await manager.fetch_active_tenants() @@ -45,17 +71,21 @@ async def run_async_queues(manager: AsyncQueueManager, app, port, host): logger.error(f"An error occurred while running async 
queues: {e}", exc_info=True) sys.exit(1) finally: - logger.info("Signal received, shutting down...") - if not tenant_api_available: - sys.exit(0) - if queue_task and not queue_task.done(): - queue_task.cancel() - if webserver_task and not webserver_task.done(): - webserver_task.cancel() + if shutdown_flag: + logger.info("Graceful shutdown already in progress.") + else: + logger.warning("Initiating shutdown due to error or manual interruption.") + if not tenant_api_available: + sys.exit(0) + if queue_task and not queue_task.done(): + queue_task.cancel() + await manager.shutdown() - await manager.shutdown() + if webserver_task and not webserver_task.done(): + webserver_task.cancel() - await asyncio.gather(queue_task, webserver_task, return_exceptions=True) + await asyncio.gather(queue_task, webserver_task, return_exceptions=True) + logger.info("Shutdown complete.") def start_standard_queue_consumer( diff --git a/pyinfra/webserver/utils.py b/pyinfra/webserver/utils.py index c0607b1..0db6048 100644 --- a/pyinfra/webserver/utils.py +++ b/pyinfra/webserver/utils.py @@ -16,6 +16,13 @@ from pyinfra.config.loader import validate_settings from pyinfra.config.validators import webserver_validators +class PyInfraUvicornServer(uvicorn.Server): + # this is a workaround to enable custom signal handlers + # https://github.com/encode/uvicorn/issues/1579 + def install_signal_handlers(self): + pass + + @retry( tries=5, exceptions=Exception, @@ -53,7 +60,7 @@ def create_webserver_thread(app: FastAPI, port: int, host: str) -> threading.Thr async def run_async_webserver(app: FastAPI, port: int, host: str): """Run the FastAPI web server async.""" config = uvicorn.Config(app, host=host, port=port, log_level=logging.WARNING) - server = uvicorn.Server(config) + server = PyInfraUvicornServer(config) try: await server.serve() diff --git a/scripts/send_sigterm.py b/scripts/send_sigterm.py new file mode 100644 index 0000000..6f1ce52 --- /dev/null +++ b/scripts/send_sigterm.py @@ -0,0 +1,17 @@ 
+import os
+import signal
+import time
+
+# BE CAREFUL WITH THIS SCRIPT - THIS SIMULATES A SIGTERM FROM KUBERNETES
+target_pid = int(input("Enter the PID of the target script: "))
+
+print(f"Sending SIGTERM to PID {target_pid}...")
+time.sleep(1)
+
+try:
+    os.kill(target_pid, signal.SIGTERM)
+    print("SIGTERM sent.")
+except ProcessLookupError:
+    print("Process not found.")
+except PermissionError:
+    print("Permission denied. Are you trying to signal a process you don't own?")