From 45377ba1722b891210bb42d197c74caf80e8529c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jonathan=20K=C3=B6ssler?= Date: Fri, 27 Sep 2024 17:11:10 +0200 Subject: [PATCH] feat: improve on close callback and simplify exception handling --- pyinfra/examples.py | 40 +++------------- pyinfra/queue/async_manager.py | 86 +++++++++++++++++----------------- 2 files changed, 49 insertions(+), 77 deletions(-) diff --git a/pyinfra/examples.py b/pyinfra/examples.py index 7aca2ab..ddd45ea 100644 --- a/pyinfra/examples.py +++ b/pyinfra/examples.py @@ -1,6 +1,7 @@ import asyncio import sys +from aiormq.exceptions import AMQPConnectionError from dynaconf import Dynaconf from fastapi import FastAPI from kn_utils.logging import logger @@ -27,31 +28,26 @@ from pyinfra.webserver.utils import ( ) -@retry( - stop=stop_after_attempt(5), - wait=wait_exponential(multiplier=1, min=4, max=60), - retry=retry_if_exception_type((Exception,)), # You might want to be more specific here - reraise=True, -) async def run_async_queues(manager: AsyncQueueManager, app, port, host): """Run the async webserver and the async queue manager concurrently.""" queue_task = None webserver_task = None try: - queue_task = asyncio.create_task(manager.run()) - webserver_task = asyncio.create_task(run_async_webserver(app, port, host)) - + queue_task = asyncio.create_task(manager.run(), name="queues") + webserver_task = asyncio.create_task(run_async_webserver(app, port, host), name="webserver") await asyncio.gather(queue_task, webserver_task) except asyncio.CancelledError: logger.info("Main task was cancelled, initiating shutdown.") + except AMQPConnectionError as e: + logger.warning(f"AMQPConnectionError: {e} - shutting down.") except Exception as e: logger.error(f"An error occurred while running async queues: {e}", exc_info=True) sys.exit(1) finally: logger.info("Signal received, shutting down...") - if queue_task: + if queue_task and not queue_task.done(): queue_task.cancel() - if webserver_task: + if webserver_task and not webserver_task.done(): webserver_task.cancel() await manager.shutdown() @@ -59,28 +55,6 @@ async def run_async_queues(manager: AsyncQueueManager, app, port, host): await asyncio.gather(queue_task, webserver_task, return_exceptions=True) -async def graceful_shutdown(queue_task: asyncio.Task, webserver_task: asyncio.Task, manager: AsyncQueueManager): - """Ensure the graceful shutdown of tasks. - - Args: - queue_task (asyncio.Task): Task instance of manager.run() - webserver_task (asyncio.Task): Task instance of webserver - manager (AsyncQueueManager): Queue manager object - """ - try: - if queue_task: - queue_task.cancel() - if webserver_task: - webserver_task.cancel() - - await manager.shutdown() - - await asyncio.gather(queue_task, webserver_task, return_exceptions=True) - logger.info("Shutdown complete.") - except Exception as e: - logger.error(f"Error during graceful shutdown: {e}", exc_info=True) - - def start_standard_queue_consumer( callback: Callback, settings: Dynaconf, diff --git a/pyinfra/queue/async_manager.py b/pyinfra/queue/async_manager.py index 2800920..4d31163 100644 --- a/pyinfra/queue/async_manager.py +++ b/pyinfra/queue/async_manager.py @@ -5,7 +5,7 @@ from dataclasses import dataclass, field from typing import Any, Callable, Dict, Set import aiohttp -from aio_pika import ExchangeType, IncomingMessage, Message, connect_robust +from aio_pika import ExchangeType, IncomingMessage, Message, connect from aio_pika.abc import ( AbstractChannel, AbstractConnection, @@ -13,7 +13,12 @@ from aio_pika.abc import ( AbstractIncomingMessage, AbstractQueue, ) -from aio_pika.exceptions import AMQPConnectionError, ChannelInvalidStateError +from aio_pika.exceptions import ( + ChannelClosed, + ChannelInvalidStateError, + ConnectionClosed, +) +from aiormq.exceptions import AMQPConnectionError from kn_utils.logging import logger from tenacity import ( retry, @@ -49,9 +54,6 @@ class RabbitMQConfig: "login": self.username, "password": self.password, "client_properties": {"heartbeat": self.heartbeat}, - # aio_pika automatically and infinitely tries to reconnect to the broker when the connection is closed, - # which we don't want in order to enable scale to zero. Therefore, reconnect_interval is set to None. - # "reconnect_interval": None, } @@ -87,7 +89,7 @@ class AsyncQueueManager: ) async def connect(self) -> None: logger.info("Attempting to connect to RabbitMQ...") - self.connection = await connect_robust(**self.config.connection_params) + self.connection = await connect(**self.config.connection_params) self.connection.close_callbacks.add(self.on_connection_close) self.channel = await self.connection.channel() await self.channel.set_qos(prefetch_count=1) @@ -96,26 +98,18 @@ class AsyncQueueManager: async def on_connection_close(self, sender, exc): """This is a callback for unexpected connection closures.""" logger.debug(f"Sender: {sender}") - logger.error("Connection to RabbitMQ lost: %s. Attempting to reconnect...", exc) - try: - await self.reconnect() - except Exception as e: - logger.error("Reconnection failed. Shutting down: %s", e) - await self.shutdown() - - @retry( - stop=stop_after_attempt(3), - wait=wait_exponential(multiplier=1, min=4, max=60), - retry=retry_if_exception_type(AMQPConnectionError), - reraise=True, - ) - async def reconnect(self) -> None: - logger.info("Attempting to reconnect to RabbitMQ...") - await self.connect() - await self.setup_exchanges() - await self.initialize_tenant_queues() - await self.setup_tenant_queue() - logger.info("Reconnected to RabbitMQ successfully. Press CTRL+C to exit.") + if isinstance(exc, ConnectionClosed): + logger.warning("Connection to RabbitMQ lost. Attempting to reconnect...") + try: + await self.run() + logger.debug("Reconnected to RabbitMQ successfully") + except Exception as e: + logger.warning(f"Failed to reconnect to RabbitMQ: {e}") + # cancel queue manager and webserver to shutdown service + tasks = [t for t in asyncio.all_tasks() if t is not asyncio.current_task()] + [task.cancel() for task in tasks if task.get_name() in ["queues", "webserver"]] + else: + logger.debug("Connection closed on purpose.") async def is_ready(self) -> bool: if self.connection is None or self.connection.is_closed: @@ -303,25 +297,16 @@ class AsyncQueueManager: await self.create_tenant_queues(tenant_id) async def run(self) -> None: - try: - await self.connect() - await self.setup_exchanges() - await self.initialize_tenant_queues() - await self.setup_tenant_queue() + await self.connect() + await self.setup_exchanges() + await self.initialize_tenant_queues() + await self.setup_tenant_queue() - logger.info("RabbitMQ handler is running. Press CTRL+C to exit.") - except AMQPConnectionError as e: - logger.error(f"Failed to establish connection to RabbitMQ: {e}", exc_info=True) - # TODO: implement a custom exception handling strategy here - except asyncio.CancelledError: - logger.warning("Operation cancelled.") - except Exception as e: - logger.error(f"An error occurred: {e}", exc_info=True) + logger.info("RabbitMQ handler is running. Press CTRL+C to exit.") - async def shutdown(self) -> None: - logger.info("Shutting down RabbitMQ handler...") + async def close_channels(self) -> None: try: - if self.channel: + if self.channel and not self.channel.is_closed: # Cancel queues to stop fetching messages logger.debug("Cancelling queues...") for tenant, queue in self.tenant_queues.items(): @@ -334,14 +319,27 @@ class AsyncQueueManager: logger.debug("Channel closed.") else: logger.debug("No channel to close.") + except ChannelClosed: + logger.warning("Channel was already closed.") + except ConnectionClosed: + logger.warning("Connection was lost, unable to close channel.") except Exception as e: - logger.error(f"Error during shutdown: {e}") + logger.error(f"Error during channel shutdown: {e}") + + async def close_connection(self) -> None: try: - if self.connection: + if self.connection and not self.connection.is_closed: await self.connection.close() + logger.debug("Connection closed.") else: logger.debug("No connection to close.") + except ConnectionClosed: + logger.warning("Connection was already closed.") except Exception as e: logger.error(f"Error closing connection: {e}") + async def shutdown(self) -> None: + logger.info("Shutting down RabbitMQ handler...") + await self.close_channels() + await self.close_connection() logger.info("RabbitMQ handler shut down successfully.")