feat: improve on close callback and simplify exception handling

This commit is contained in:
Jonathan Kössler 2024-09-27 17:11:10 +02:00
parent f855224e29
commit 45377ba172
2 changed files with 49 additions and 77 deletions

View File

@ -1,6 +1,7 @@
import asyncio import asyncio
import sys import sys
from aiormq.exceptions import AMQPConnectionError
from dynaconf import Dynaconf from dynaconf import Dynaconf
from fastapi import FastAPI from fastapi import FastAPI
from kn_utils.logging import logger from kn_utils.logging import logger
@ -27,31 +28,26 @@ from pyinfra.webserver.utils import (
) )
@retry(
stop=stop_after_attempt(5),
wait=wait_exponential(multiplier=1, min=4, max=60),
retry=retry_if_exception_type((Exception,)), # You might want to be more specific here
reraise=True,
)
async def run_async_queues(manager: AsyncQueueManager, app, port, host): async def run_async_queues(manager: AsyncQueueManager, app, port, host):
"""Run the async webserver and the async queue manager concurrently.""" """Run the async webserver and the async queue manager concurrently."""
queue_task = None queue_task = None
webserver_task = None webserver_task = None
try: try:
queue_task = asyncio.create_task(manager.run()) queue_task = asyncio.create_task(manager.run(), name="queues")
webserver_task = asyncio.create_task(run_async_webserver(app, port, host)) webserver_task = asyncio.create_task(run_async_webserver(app, port, host), name="webserver")
await asyncio.gather(queue_task, webserver_task) await asyncio.gather(queue_task, webserver_task)
except asyncio.CancelledError: except asyncio.CancelledError:
logger.info("Main task was cancelled, initiating shutdown.") logger.info("Main task was cancelled, initiating shutdown.")
except AMQPConnectionError as e:
logger.warning(f"AMQPConnectionError: {e} - shutting down.")
except Exception as e: except Exception as e:
logger.error(f"An error occurred while running async queues: {e}", exc_info=True) logger.error(f"An error occurred while running async queues: {e}", exc_info=True)
sys.exit(1) sys.exit(1)
finally: finally:
logger.info("Signal received, shutting down...") logger.info("Signal received, shutting down...")
if queue_task: if queue_task and not queue_task.done():
queue_task.cancel() queue_task.cancel()
if webserver_task: if webserver_task and not webserver_task.done():
webserver_task.cancel() webserver_task.cancel()
await manager.shutdown() await manager.shutdown()
@ -59,28 +55,6 @@ async def run_async_queues(manager: AsyncQueueManager, app, port, host):
await asyncio.gather(queue_task, webserver_task, return_exceptions=True) await asyncio.gather(queue_task, webserver_task, return_exceptions=True)
async def graceful_shutdown(queue_task: asyncio.Task, webserver_task: asyncio.Task, manager: AsyncQueueManager):
"""Ensure the graceful shutdown of tasks.
Args:
queue_task (asyncio.Task): Task instance of manager.run()
webserver_task (asyncio.Task): Task instance of webserver
manager (AsyncQueueManager): Queue manager object
"""
try:
if queue_task:
queue_task.cancel()
if webserver_task:
webserver_task.cancel()
await manager.shutdown()
await asyncio.gather(queue_task, webserver_task, return_exceptions=True)
logger.info("Shutdown complete.")
except Exception as e:
logger.error(f"Error during graceful shutdown: {e}", exc_info=True)
def start_standard_queue_consumer( def start_standard_queue_consumer(
callback: Callback, callback: Callback,
settings: Dynaconf, settings: Dynaconf,

View File

@ -5,7 +5,7 @@ from dataclasses import dataclass, field
from typing import Any, Callable, Dict, Set from typing import Any, Callable, Dict, Set
import aiohttp import aiohttp
from aio_pika import ExchangeType, IncomingMessage, Message, connect_robust from aio_pika import ExchangeType, IncomingMessage, Message, connect
from aio_pika.abc import ( from aio_pika.abc import (
AbstractChannel, AbstractChannel,
AbstractConnection, AbstractConnection,
@ -13,7 +13,12 @@ from aio_pika.abc import (
AbstractIncomingMessage, AbstractIncomingMessage,
AbstractQueue, AbstractQueue,
) )
from aio_pika.exceptions import AMQPConnectionError, ChannelInvalidStateError from aio_pika.exceptions import (
ChannelClosed,
ChannelInvalidStateError,
ConnectionClosed,
)
from aiormq.exceptions import AMQPConnectionError
from kn_utils.logging import logger from kn_utils.logging import logger
from tenacity import ( from tenacity import (
retry, retry,
@ -49,9 +54,6 @@ class RabbitMQConfig:
"login": self.username, "login": self.username,
"password": self.password, "password": self.password,
"client_properties": {"heartbeat": self.heartbeat}, "client_properties": {"heartbeat": self.heartbeat},
# aio_pika automatically and infinitely tries to reconnect to the broker when the connection is closed,
# which we don't want in order to enable scale to zero. Therefore, reconnect_interval is set to None.
# "reconnect_interval": None,
} }
@ -87,7 +89,7 @@ class AsyncQueueManager:
) )
async def connect(self) -> None: async def connect(self) -> None:
logger.info("Attempting to connect to RabbitMQ...") logger.info("Attempting to connect to RabbitMQ...")
self.connection = await connect_robust(**self.config.connection_params) self.connection = await connect(**self.config.connection_params)
self.connection.close_callbacks.add(self.on_connection_close) self.connection.close_callbacks.add(self.on_connection_close)
self.channel = await self.connection.channel() self.channel = await self.connection.channel()
await self.channel.set_qos(prefetch_count=1) await self.channel.set_qos(prefetch_count=1)
@ -96,26 +98,18 @@ class AsyncQueueManager:
async def on_connection_close(self, sender, exc): async def on_connection_close(self, sender, exc):
"""This is a callback for unexpected connection closures.""" """This is a callback for unexpected connection closures."""
logger.debug(f"Sender: {sender}") logger.debug(f"Sender: {sender}")
logger.error("Connection to RabbitMQ lost: %s. Attempting to reconnect...", exc) if isinstance(exc, ConnectionClosed):
try: logger.warning("Connection to RabbitMQ lost. Attempting to reconnect...")
await self.reconnect() try:
except Exception as e: await self.run()
logger.error("Reconnection failed. Shutting down: %s", e) logger.debug("Reconnected to RabbitMQ successfully")
await self.shutdown() except Exception as e:
logger.warning(f"Failed to reconnect to RabbitMQ: {e}")
@retry( # cancel queue manager and webserver to shutdown service
stop=stop_after_attempt(3), tasks = [t for t in asyncio.all_tasks() if t is not asyncio.current_task()]
wait=wait_exponential(multiplier=1, min=4, max=60), [task.cancel() for task in tasks if task.get_name() in ["queues", "webserver"]]
retry=retry_if_exception_type(AMQPConnectionError), else:
reraise=True, logger.debug("Connection closed on purpose.")
)
async def reconnect(self) -> None:
logger.info("Attempting to reconnect to RabbitMQ...")
await self.connect()
await self.setup_exchanges()
await self.initialize_tenant_queues()
await self.setup_tenant_queue()
logger.info("Reconnected to RabbitMQ successfully. Press CTRL+C to exit.")
async def is_ready(self) -> bool: async def is_ready(self) -> bool:
if self.connection is None or self.connection.is_closed: if self.connection is None or self.connection.is_closed:
@ -303,25 +297,16 @@ class AsyncQueueManager:
await self.create_tenant_queues(tenant_id) await self.create_tenant_queues(tenant_id)
async def run(self) -> None: async def run(self) -> None:
try: await self.connect()
await self.connect() await self.setup_exchanges()
await self.setup_exchanges() await self.initialize_tenant_queues()
await self.initialize_tenant_queues() await self.setup_tenant_queue()
await self.setup_tenant_queue()
logger.info("RabbitMQ handler is running. Press CTRL+C to exit.") logger.info("RabbitMQ handler is running. Press CTRL+C to exit.")
except AMQPConnectionError as e:
logger.error(f"Failed to establish connection to RabbitMQ: {e}", exc_info=True)
# TODO: implement a custom exception handling strategy here
except asyncio.CancelledError:
logger.warning("Operation cancelled.")
except Exception as e:
logger.error(f"An error occurred: {e}", exc_info=True)
async def shutdown(self) -> None: async def close_channels(self) -> None:
logger.info("Shutting down RabbitMQ handler...")
try: try:
if self.channel: if self.channel and not self.channel.is_closed:
# Cancel queues to stop fetching messages # Cancel queues to stop fetching messages
logger.debug("Cancelling queues...") logger.debug("Cancelling queues...")
for tenant, queue in self.tenant_queues.items(): for tenant, queue in self.tenant_queues.items():
@ -334,14 +319,27 @@ class AsyncQueueManager:
logger.debug("Channel closed.") logger.debug("Channel closed.")
else: else:
logger.debug("No channel to close.") logger.debug("No channel to close.")
except ChannelClosed:
logger.warning("Channel was already closed.")
except ConnectionClosed:
logger.warning("Connection was lost, unable to close channel.")
except Exception as e: except Exception as e:
logger.error(f"Error during shutdown: {e}") logger.error(f"Error during channel shutdown: {e}")
async def close_connection(self) -> None:
try: try:
if self.connection: if self.connection and not self.connection.is_closed:
await self.connection.close() await self.connection.close()
logger.debug("Connection closed.")
else: else:
logger.debug("No connection to close.") logger.debug("No connection to close.")
except ConnectionClosed:
logger.warning("Connection was already closed.")
except Exception as e: except Exception as e:
logger.error(f"Error closing connection: {e}") logger.error(f"Error closing connection: {e}")
async def shutdown(self) -> None:
logger.info("Shutting down RabbitMQ handler...")
await self.close_channels()
await self.close_connection()
logger.info("RabbitMQ handler shut down successfully.") logger.info("RabbitMQ handler shut down successfully.")