From e2edfa7260e362cc0dec8f7531fe66d30fe38560 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jonathan=20K=C3=B6ssler?= Date: Thu, 26 Sep 2024 10:33:05 +0200 Subject: [PATCH 1/5] fix: simplify webserver shutdown --- pyinfra/webserver/utils.py | 21 +++++++++------------ 1 file changed, 9 insertions(+), 12 deletions(-) diff --git a/pyinfra/webserver/utils.py b/pyinfra/webserver/utils.py index 5debb15..b87f380 100644 --- a/pyinfra/webserver/utils.py +++ b/pyinfra/webserver/utils.py @@ -4,13 +4,18 @@ import logging import signal import threading import time -from tenacity import retry, stop_after_attempt, wait_exponential, retry_if_exception_type from typing import Callable import uvicorn from dynaconf import Dynaconf from fastapi import FastAPI from kn_utils.logging import logger +from tenacity import ( + retry, + retry_if_exception_type, + stop_after_attempt, + wait_exponential, +) from pyinfra.config.loader import validate_settings from pyinfra.config.validators import webserver_validators @@ -56,20 +61,12 @@ async def run_async_webserver(app: FastAPI, port: int, host: str): config = uvicorn.Config(app, host=host, port=port, log_level=logging.WARNING) server = uvicorn.Server(config) - async def shutdown(signal): - logger.info(f"Received signal {signal.name}, shutting down webserver...") - await app.shutdown() - await app.cleanup() - logger.info("Shutdown complete.") - - loop = asyncio.get_event_loop() - for sig in (signal.SIGTERM, signal.SIGINT): - loop.add_signal_handler(sig, lambda s=sig: asyncio.create_task(shutdown(s))) - try: await server.serve() except asyncio.CancelledError: - pass + logger.info("Webserver was cancelled.") + finally: + logger.info("Webserver has been shut down.") HealthFunction = Callable[[], bool] From 4119a7d7d71462c035e254a93a5ebd63bc18d611 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jonathan=20K=C3=B6ssler?= Date: Thu, 26 Sep 2024 11:05:12 +0200 Subject: [PATCH 2/5] chore: bump version --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index a39299b..f4aee89 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "pyinfra" -version = "3.2.10" +version = "3.2.11" description = "" authors = ["Team Research "] license = "All rights reseverd" From 541219177f98545f0e1accd9d2e450f994ad482b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jonathan=20K=C3=B6ssler?= Date: Thu, 26 Sep 2024 12:28:55 +0200 Subject: [PATCH 3/5] feat: add error handling to shutdown logic --- pyinfra/examples.py | 52 +++++++++++++++++++++++----------- pyinfra/queue/async_manager.py | 38 ++++++++++++++++--------- pyinfra/webserver/utils.py | 4 +++ 3 files changed, 64 insertions(+), 30 deletions(-) diff --git a/pyinfra/examples.py b/pyinfra/examples.py index d7fde30..7aca2ab 100644 --- a/pyinfra/examples.py +++ b/pyinfra/examples.py @@ -33,34 +33,52 @@ from pyinfra.webserver.utils import ( retry=retry_if_exception_type((Exception,)), # You might want to be more specific here reraise=True, ) -async def run_async_queues(manager, app, port, host): +async def run_async_queues(manager: AsyncQueueManager, app, port, host): """Run the async webserver and the async queue manager concurrently.""" + queue_task = None + webserver_task = None try: - await manager.run() - await run_async_webserver(app, port, host) + queue_task = asyncio.create_task(manager.run()) + webserver_task = asyncio.create_task(run_async_webserver(app, port, host)) + + await asyncio.gather(queue_task, webserver_task) except asyncio.CancelledError: - logger.info("Main task is cancelled.") + logger.info("Main task was cancelled, initiating shutdown.") except Exception as e: logger.error(f"An error occurred while running async queues: {e}", exc_info=True) sys.exit(1) finally: logger.info("Signal received, shutting down...") + if queue_task: + queue_task.cancel() + if webserver_task: + webserver_task.cancel() + await manager.shutdown() + await asyncio.gather(queue_task, webserver_task, return_exceptions=True) -# async def run_async_queues(manager, app, port, host): -# server = None -# try: -# await manager.run() -# server = await asyncio.start_server(app, host, port) -# await server.serve_forever() -# except Exception as e: -# logger.error(f"An error occurred while running async queues: {e}") -# finally: -# if server: -# server.close() -# await server.wait_closed() -# await manager.shutdown() + +async def graceful_shutdown(queue_task: asyncio.Task, webserver_task: asyncio.Task, manager: AsyncQueueManager): + """Ensure the graceful shutdown of tasks. + + Args: + queue_task (asyncio.Task): Task instance of manager.run() + webserver_task (asyncio.Task): Task instance of webserver + manager (AsyncQueueManager): Queue manager object + """ + try: + if queue_task: + queue_task.cancel() + if webserver_task: + webserver_task.cancel() + + await manager.shutdown() + + await asyncio.gather(queue_task, webserver_task, return_exceptions=True) + logger.info("Shutdown complete.") + except Exception as e: + logger.error(f"Error during graceful shutdown: {e}", exc_info=True) def start_standard_queue_consumer( diff --git a/pyinfra/queue/async_manager.py b/pyinfra/queue/async_manager.py index c3e6df0..802d35a 100644 --- a/pyinfra/queue/async_manager.py +++ b/pyinfra/queue/async_manager.py @@ -295,17 +295,29 @@ class AsyncQueueManager: async def shutdown(self) -> None: logger.info("Shutting down RabbitMQ handler...") - if self.channel: - # Cancel queues to stop fetching messages - logger.debug("Cancelling queues...") - for tenant, queue in self.tenant_queues.items(): - await queue.cancel(self.consumer_tags[tenant]) - await self.tenant_exchange_queue.cancel(self.consumer_tags["tenant_exchange_queue"]) - while self.message_count != 0: - logger.debug(f"Messages are still being processed: {self.message_count=} ") - await asyncio.sleep(2) - await self.channel.close() - if self.connection: - await self.connection.close() + try: + if self.channel: + # Cancel queues to stop fetching messages + logger.debug("Cancelling queues...") + for tenant, queue in self.tenant_queues.items(): + await queue.cancel(self.consumer_tags[tenant]) + await self.tenant_exchange_queue.cancel(self.consumer_tags["tenant_exchange_queue"]) + while self.message_count != 0: + logger.debug(f"Messages are still being processed: {self.message_count=} ") + await asyncio.sleep(2) + await self.channel.close() + logger.debug("Channel closed.") + else: + logger.debug("No channel to close.") + except Exception as e: + logger.error(f"Error during shutdown: {e}") + try: + if self.connection: + await self.connection.close() + else: + logger.debug("No connection to close.") + except Exception as e: + logger.error(f"Error closing connection: {e}") + logger.info("RabbitMQ handler shut down successfully.") - sys.exit(0) + # sys.exit(0) diff --git a/pyinfra/webserver/utils.py b/pyinfra/webserver/utils.py index b87f380..478ffe7 100644 --- a/pyinfra/webserver/utils.py +++ b/pyinfra/webserver/utils.py @@ -65,6 +65,10 @@ async def run_async_webserver(app: FastAPI, port: int, host: str): await server.serve() except asyncio.CancelledError: logger.info("Webserver was cancelled.") + server.should_exit = True + await server.shutdown() + except Exception as e: + logger.error(f"Error while running the webserver: {e}", exc_info=True) finally: logger.info("Webserver has been shut down.") From f855224e290c910fd33fe8d3cc140e42c5f30b9f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jonathan=20K=C3=B6ssler?= Date: Fri, 27 Sep 2024 10:00:41 +0200 Subject: [PATCH 4/5] feat: add on close callback --- pyinfra/queue/async_manager.py | 32 ++++++++++++++++++++++++++++---- 1 file changed, 28 insertions(+), 4 deletions(-) diff --git a/pyinfra/queue/async_manager.py b/pyinfra/queue/async_manager.py index 802d35a..2800920 100644 --- a/pyinfra/queue/async_manager.py +++ b/pyinfra/queue/async_manager.py @@ -1,8 +1,6 @@ import asyncio import concurrent.futures import json -import signal -import sys from dataclasses import dataclass, field from typing import Any, Callable, Dict, Set @@ -16,7 +14,6 @@ from aio_pika.abc import ( AbstractQueue, ) from aio_pika.exceptions import AMQPConnectionError, ChannelInvalidStateError -from aiormq.exceptions import AMQPConnectionError from kn_utils.logging import logger from tenacity import ( retry, @@ -52,6 +49,9 @@ class RabbitMQConfig: "login": self.username, "password": self.password, "client_properties": {"heartbeat": self.heartbeat}, + # aio_pika automatically and infinitely tries to reconnect to the broker when the connection is closed, + # which we don't want in order to enable scale to zero. Therefore, reconnect_interval is set to None. + # "reconnect_interval": None, } @@ -88,10 +88,35 @@ class AsyncQueueManager: async def connect(self) -> None: logger.info("Attempting to connect to RabbitMQ...") self.connection = await connect_robust(**self.config.connection_params) + self.connection.close_callbacks.add(self.on_connection_close) self.channel = await self.connection.channel() await self.channel.set_qos(prefetch_count=1) logger.info("Successfully connected to RabbitMQ") + async def on_connection_close(self, sender, exc): + """This is a callback for unexpected connection closures.""" + logger.debug(f"Sender: {sender}") + logger.error("Connection to RabbitMQ lost: %s. Attempting to reconnect...", exc) + try: + await self.reconnect() + except Exception as e: + logger.error("Reconnection failed. Shutting down: %s", e) + await self.shutdown() + + @retry( + stop=stop_after_attempt(3), + wait=wait_exponential(multiplier=1, min=4, max=60), + retry=retry_if_exception_type(AMQPConnectionError), + reraise=True, + ) + async def reconnect(self) -> None: + logger.info("Attempting to reconnect to RabbitMQ...") + await self.connect() + await self.setup_exchanges() + await self.initialize_tenant_queues() + await self.setup_tenant_queue() + logger.info("Reconnected to RabbitMQ successfully. Press CTRL+C to exit.") + async def is_ready(self) -> bool: if self.connection is None or self.connection.is_closed: try: @@ -320,4 +345,3 @@ class AsyncQueueManager: logger.error(f"Error closing connection: {e}") logger.info("RabbitMQ handler shut down successfully.") - # sys.exit(0) From 45377ba1722b891210bb42d197c74caf80e8529c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jonathan=20K=C3=B6ssler?= Date: Fri, 27 Sep 2024 17:11:10 +0200 Subject: [PATCH 5/5] feat: improve on close callback and simplify exception handling --- pyinfra/examples.py | 40 +++------------- pyinfra/queue/async_manager.py | 86 +++++++++++++++++----------------- 2 files changed, 49 insertions(+), 77 deletions(-) diff --git a/pyinfra/examples.py b/pyinfra/examples.py index 7aca2ab..ddd45ea 100644 --- a/pyinfra/examples.py +++ b/pyinfra/examples.py @@ -1,6 +1,7 @@ import asyncio import sys +from aiormq.exceptions import AMQPConnectionError from dynaconf import Dynaconf from fastapi import FastAPI from kn_utils.logging import logger @@ -27,31 +28,26 @@ from pyinfra.webserver.utils import ( ) -@retry( - stop=stop_after_attempt(5), - wait=wait_exponential(multiplier=1, min=4, max=60), - retry=retry_if_exception_type((Exception,)), # You might want to be more specific here - reraise=True, -) async def run_async_queues(manager: AsyncQueueManager, app, port, host): """Run the async webserver and the async queue manager concurrently.""" queue_task = None webserver_task = None try: - queue_task = asyncio.create_task(manager.run()) - webserver_task = asyncio.create_task(run_async_webserver(app, port, host)) - + queue_task = asyncio.create_task(manager.run(), name="queues") + webserver_task = asyncio.create_task(run_async_webserver(app, port, host), name="webserver") await asyncio.gather(queue_task, webserver_task) except asyncio.CancelledError: logger.info("Main task was cancelled, initiating shutdown.") + except AMQPConnectionError as e: + logger.warning(f"AMQPConnectionError: {e} - shutting down.") except Exception as e: logger.error(f"An error occurred while running async queues: {e}", exc_info=True) sys.exit(1) finally: logger.info("Signal received, shutting down...") - if queue_task: + if queue_task and not queue_task.done(): queue_task.cancel() - if webserver_task: + if webserver_task and not webserver_task.done(): webserver_task.cancel() await manager.shutdown() @@ -59,28 +55,6 @@ async def run_async_queues(manager: AsyncQueueManager, app, port, host): await asyncio.gather(queue_task, webserver_task, return_exceptions=True) -async def graceful_shutdown(queue_task: asyncio.Task, webserver_task: asyncio.Task, manager: AsyncQueueManager): - """Ensure the graceful shutdown of tasks. - - Args: - queue_task (asyncio.Task): Task instance of manager.run() - webserver_task (asyncio.Task): Task instance of webserver - manager (AsyncQueueManager): Queue manager object - """ - try: - if queue_task: - queue_task.cancel() - if webserver_task: - webserver_task.cancel() - - await manager.shutdown() - - await asyncio.gather(queue_task, webserver_task, return_exceptions=True) - logger.info("Shutdown complete.") - except Exception as e: - logger.error(f"Error during graceful shutdown: {e}", exc_info=True) - - def start_standard_queue_consumer( callback: Callback, settings: Dynaconf, diff --git a/pyinfra/queue/async_manager.py b/pyinfra/queue/async_manager.py index 2800920..4d31163 100644 --- a/pyinfra/queue/async_manager.py +++ b/pyinfra/queue/async_manager.py @@ -5,7 +5,7 @@ from dataclasses import dataclass, field from typing import Any, Callable, Dict, Set import aiohttp -from aio_pika import ExchangeType, IncomingMessage, Message, connect_robust +from aio_pika import ExchangeType, IncomingMessage, Message, connect from aio_pika.abc import ( AbstractChannel, AbstractConnection, @@ -13,7 +13,12 @@ from aio_pika.abc import ( AbstractIncomingMessage, AbstractQueue, ) -from aio_pika.exceptions import AMQPConnectionError, ChannelInvalidStateError +from aio_pika.exceptions import ( + ChannelClosed, + ChannelInvalidStateError, + ConnectionClosed, +) +from aiormq.exceptions import AMQPConnectionError from kn_utils.logging import logger from tenacity import ( retry, @@ -49,9 +54,6 @@ class RabbitMQConfig: "login": self.username, "password": self.password, "client_properties": {"heartbeat": self.heartbeat}, - # aio_pika automatically and infinitely tries to reconnect to the broker when the connection is closed, - # which we don't want in order to enable scale to zero. Therefore, reconnect_interval is set to None. - # "reconnect_interval": None, } @@ -87,7 +89,7 @@ class AsyncQueueManager: ) async def connect(self) -> None: logger.info("Attempting to connect to RabbitMQ...") - self.connection = await connect_robust(**self.config.connection_params) + self.connection = await connect(**self.config.connection_params) self.connection.close_callbacks.add(self.on_connection_close) self.channel = await self.connection.channel() await self.channel.set_qos(prefetch_count=1) @@ -96,26 +98,18 @@ class AsyncQueueManager: async def on_connection_close(self, sender, exc): """This is a callback for unexpected connection closures.""" logger.debug(f"Sender: {sender}") - logger.error("Connection to RabbitMQ lost: %s. Attempting to reconnect...", exc) - try: - await self.reconnect() - except Exception as e: - logger.error("Reconnection failed. Shutting down: %s", e) - await self.shutdown() - - @retry( - stop=stop_after_attempt(3), - wait=wait_exponential(multiplier=1, min=4, max=60), - retry=retry_if_exception_type(AMQPConnectionError), - reraise=True, - ) - async def reconnect(self) -> None: - logger.info("Attempting to reconnect to RabbitMQ...") - await self.connect() - await self.setup_exchanges() - await self.initialize_tenant_queues() - await self.setup_tenant_queue() - logger.info("Reconnected to RabbitMQ successfully. Press CTRL+C to exit.") + if isinstance(exc, ConnectionClosed): + logger.warning("Connection to RabbitMQ lost. Attempting to reconnect...") + try: + await self.run() + logger.debug("Reconnected to RabbitMQ successfully") + except Exception as e: + logger.warning(f"Failed to reconnect to RabbitMQ: {e}") + # cancel queue manager and webserver to shutdown service + tasks = [t for t in asyncio.all_tasks() if t is not asyncio.current_task()] + [task.cancel() for task in tasks if task.get_name() in ["queues", "webserver"]] + else: + logger.debug("Connection closed on purpose.") async def is_ready(self) -> bool: if self.connection is None or self.connection.is_closed: @@ -303,25 +297,16 @@ class AsyncQueueManager: await self.create_tenant_queues(tenant_id) async def run(self) -> None: - try: - await self.connect() - await self.setup_exchanges() - await self.initialize_tenant_queues() - await self.setup_tenant_queue() + await self.connect() + await self.setup_exchanges() + await self.initialize_tenant_queues() + await self.setup_tenant_queue() - logger.info("RabbitMQ handler is running. Press CTRL+C to exit.") - except AMQPConnectionError as e: - logger.error(f"Failed to establish connection to RabbitMQ: {e}", exc_info=True) - # TODO: implement a custom exception handling strategy here - except asyncio.CancelledError: - logger.warning("Operation cancelled.") - except Exception as e: - logger.error(f"An error occurred: {e}", exc_info=True) + logger.info("RabbitMQ handler is running. Press CTRL+C to exit.") - async def shutdown(self) -> None: - logger.info("Shutting down RabbitMQ handler...") + async def close_channels(self) -> None: try: - if self.channel: + if self.channel and not self.channel.is_closed: # Cancel queues to stop fetching messages logger.debug("Cancelling queues...") for tenant, queue in self.tenant_queues.items(): @@ -334,14 +319,27 @@ class AsyncQueueManager: logger.debug("Channel closed.") else: logger.debug("No channel to close.") + except ChannelClosed: + logger.warning("Channel was already closed.") + except ConnectionClosed: + logger.warning("Connection was lost, unable to close channel.") except Exception as e: - logger.error(f"Error during shutdown: {e}") + logger.error(f"Error during channel shutdown: {e}") + + async def close_connection(self) -> None: try: - if self.connection: + if self.connection and not self.connection.is_closed: await self.connection.close() + logger.debug("Connection closed.") else: logger.debug("No connection to close.") + except ConnectionClosed: + logger.warning("Connection was already closed.") except Exception as e: logger.error(f"Error closing connection: {e}") + async def shutdown(self) -> None: + logger.info("Shutting down RabbitMQ handler...") + await self.close_channels() + await self.close_connection() logger.info("RabbitMQ handler shut down successfully.")