diff --git a/pyinfra/examples.py b/pyinfra/examples.py index d7fde30..ddd45ea 100644 --- a/pyinfra/examples.py +++ b/pyinfra/examples.py @@ -1,6 +1,7 @@ import asyncio import sys +from aiormq.exceptions import AMQPConnectionError from dynaconf import Dynaconf from fastapi import FastAPI from kn_utils.logging import logger @@ -27,40 +28,31 @@ from pyinfra.webserver.utils import ( ) -@retry( - stop=stop_after_attempt(5), - wait=wait_exponential(multiplier=1, min=4, max=60), - retry=retry_if_exception_type((Exception,)), # You might want to be more specific here - reraise=True, -) -async def run_async_queues(manager, app, port, host): +async def run_async_queues(manager: AsyncQueueManager, app, port, host): """Run the async webserver and the async queue manager concurrently.""" + queue_task = None + webserver_task = None try: - await manager.run() - await run_async_webserver(app, port, host) + queue_task = asyncio.create_task(manager.run(), name="queues") + webserver_task = asyncio.create_task(run_async_webserver(app, port, host), name="webserver") + await asyncio.gather(queue_task, webserver_task) except asyncio.CancelledError: - logger.info("Main task is cancelled.") + logger.info("Main task was cancelled, initiating shutdown.") + except AMQPConnectionError as e: + logger.warning(f"AMQPConnectionError: {e} - shutting down.") except Exception as e: logger.error(f"An error occurred while running async queues: {e}", exc_info=True) sys.exit(1) finally: logger.info("Signal received, shutting down...") + if queue_task and not queue_task.done(): + queue_task.cancel() + if webserver_task and not webserver_task.done(): + webserver_task.cancel() + await manager.shutdown() - -# async def run_async_queues(manager, app, port, host): -# server = None -# try: -# await manager.run() -# server = await asyncio.start_server(app, host, port) -# await server.serve_forever() -# except Exception as e: -# logger.error(f"An error occurred while running async queues: {e}") -# finally: -# if server: -# server.close() -# await server.wait_closed() -# await manager.shutdown() + await asyncio.gather(queue_task, webserver_task, return_exceptions=True) def start_standard_queue_consumer( diff --git a/pyinfra/queue/async_manager.py b/pyinfra/queue/async_manager.py index c3e6df0..4d31163 100644 --- a/pyinfra/queue/async_manager.py +++ b/pyinfra/queue/async_manager.py @@ -1,13 +1,11 @@ import asyncio import concurrent.futures import json -import signal -import sys from dataclasses import dataclass, field from typing import Any, Callable, Dict, Set import aiohttp -from aio_pika import ExchangeType, IncomingMessage, Message, connect_robust +from aio_pika import ExchangeType, IncomingMessage, Message, connect from aio_pika.abc import ( AbstractChannel, AbstractConnection, @@ -15,7 +13,11 @@ from aio_pika.abc import ( AbstractIncomingMessage, AbstractQueue, ) -from aio_pika.exceptions import AMQPConnectionError, ChannelInvalidStateError +from aio_pika.exceptions import ( + ChannelClosed, + ChannelInvalidStateError, + ConnectionClosed, +) from aiormq.exceptions import AMQPConnectionError from kn_utils.logging import logger from tenacity import ( @@ -87,11 +89,28 @@ class AsyncQueueManager: ) async def connect(self) -> None: logger.info("Attempting to connect to RabbitMQ...") - self.connection = await connect_robust(**self.config.connection_params) + self.connection = await connect(**self.config.connection_params) + self.connection.close_callbacks.add(self.on_connection_close) self.channel = await self.connection.channel() await self.channel.set_qos(prefetch_count=1) logger.info("Successfully connected to RabbitMQ") + async def on_connection_close(self, sender, exc): + """This is a callback for unexpected connection closures.""" + logger.debug(f"Sender: {sender}") + if isinstance(exc, ConnectionClosed): + logger.warning("Connection to RabbitMQ lost. Attempting to reconnect...") + try: + await self.run() + logger.debug("Reconnected to RabbitMQ successfully") + except Exception as e: + logger.warning(f"Failed to reconnect to RabbitMQ: {e}") + # cancel queue manager and webserver to shutdown service + tasks = [t for t in asyncio.all_tasks() if t is not asyncio.current_task()] + [task.cancel() for task in tasks if task.get_name() in ["queues", "webserver"]] + else: + logger.debug("Connection closed on purpose.") + async def is_ready(self) -> bool: if self.connection is None or self.connection.is_closed: try: @@ -278,34 +297,49 @@ class AsyncQueueManager: await self.create_tenant_queues(tenant_id) async def run(self) -> None: - try: - await self.connect() - await self.setup_exchanges() - await self.initialize_tenant_queues() - await self.setup_tenant_queue() + await self.connect() + await self.setup_exchanges() + await self.initialize_tenant_queues() + await self.setup_tenant_queue() - logger.info("RabbitMQ handler is running. Press CTRL+C to exit.") - except AMQPConnectionError as e: - logger.error(f"Failed to establish connection to RabbitMQ: {e}", exc_info=True) - # TODO: implement a custom exception handling strategy here - except asyncio.CancelledError: - logger.warning("Operation cancelled.") + logger.info("RabbitMQ handler is running. Press CTRL+C to exit.") + + async def close_channels(self) -> None: + try: + if self.channel and not self.channel.is_closed: + # Cancel queues to stop fetching messages + logger.debug("Cancelling queues...") + for tenant, queue in self.tenant_queues.items(): + await queue.cancel(self.consumer_tags[tenant]) + await self.tenant_exchange_queue.cancel(self.consumer_tags["tenant_exchange_queue"]) + while self.message_count != 0: + logger.debug(f"Messages are still being processed: {self.message_count=} ") + await asyncio.sleep(2) + await self.channel.close() + logger.debug("Channel closed.") + else: + logger.debug("No channel to close.") + except ChannelClosed: + logger.warning("Channel was already closed.") + except ConnectionClosed: + logger.warning("Connection was lost, unable to close channel.") except Exception as e: - logger.error(f"An error occurred: {e}", exc_info=True) + logger.error(f"Error during channel shutdown: {e}") + + async def close_connection(self) -> None: + try: + if self.connection and not self.connection.is_closed: + await self.connection.close() + logger.debug("Connection closed.") + else: + logger.debug("No connection to close.") + except ConnectionClosed: + logger.warning("Connection was already closed.") + except Exception as e: + logger.error(f"Error closing connection: {e}") async def shutdown(self) -> None: logger.info("Shutting down RabbitMQ handler...") - if self.channel: - # Cancel queues to stop fetching messages - logger.debug("Cancelling queues...") - for tenant, queue in self.tenant_queues.items(): - await queue.cancel(self.consumer_tags[tenant]) - await self.tenant_exchange_queue.cancel(self.consumer_tags["tenant_exchange_queue"]) - while self.message_count != 0: - logger.debug(f"Messages are still being processed: {self.message_count=} ") - await asyncio.sleep(2) - await self.channel.close() - if self.connection: - await self.connection.close() + await self.close_channels() + await self.close_connection() logger.info("RabbitMQ handler shut down successfully.") - sys.exit(0) diff --git a/pyinfra/webserver/utils.py b/pyinfra/webserver/utils.py index 5debb15..478ffe7 100644 --- a/pyinfra/webserver/utils.py +++ b/pyinfra/webserver/utils.py @@ -4,13 +4,18 @@ import logging import signal import threading import time -from tenacity import retry, stop_after_attempt, wait_exponential, retry_if_exception_type from typing import Callable import uvicorn from dynaconf import Dynaconf from fastapi import FastAPI from kn_utils.logging import logger +from tenacity import ( + retry, + retry_if_exception_type, + stop_after_attempt, + wait_exponential, +) from pyinfra.config.loader import validate_settings from pyinfra.config.validators import webserver_validators @@ -56,20 +61,16 @@ async def run_async_webserver(app: FastAPI, port: int, host: str): config = uvicorn.Config(app, host=host, port=port, log_level=logging.WARNING) server = uvicorn.Server(config) - async def shutdown(signal): - logger.info(f"Received signal {signal.name}, shutting down webserver...") - await app.shutdown() - await app.cleanup() - logger.info("Shutdown complete.") - - loop = asyncio.get_event_loop() - for sig in (signal.SIGTERM, signal.SIGINT): - loop.add_signal_handler(sig, lambda s=sig: asyncio.create_task(shutdown(s))) - try: await server.serve() except asyncio.CancelledError: - pass + logger.info("Webserver was cancelled.") + server.should_exit = True + await server.shutdown() + except Exception as e: + logger.error(f"Error while running the webserver: {e}", exc_info=True) + finally: + logger.info("Webserver has been shut down.") HealthFunction = Callable[[], bool] diff --git a/pyproject.toml b/pyproject.toml index a39299b..f4aee89 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "pyinfra" -version = "3.2.10" +version = "3.2.11" description = "" authors = ["Team Research "] license = "All rights reseverd"