diff --git a/pyinfra/examples.py b/pyinfra/examples.py index d7fde30..7aca2ab 100644 --- a/pyinfra/examples.py +++ b/pyinfra/examples.py @@ -33,34 +33,52 @@ from pyinfra.webserver.utils import ( retry=retry_if_exception_type((Exception,)), # You might want to be more specific here reraise=True, ) -async def run_async_queues(manager, app, port, host): +async def run_async_queues(manager: AsyncQueueManager, app, port, host): """Run the async webserver and the async queue manager concurrently.""" + queue_task = None + webserver_task = None try: - await manager.run() - await run_async_webserver(app, port, host) + queue_task = asyncio.create_task(manager.run()) + webserver_task = asyncio.create_task(run_async_webserver(app, port, host)) + + await asyncio.gather(queue_task, webserver_task) except asyncio.CancelledError: - logger.info("Main task is cancelled.") + logger.info("Main task was cancelled, initiating shutdown.") except Exception as e: logger.error(f"An error occurred while running async queues: {e}", exc_info=True) sys.exit(1) finally: logger.info("Signal received, shutting down...") + if queue_task: + queue_task.cancel() + if webserver_task: + webserver_task.cancel() + await manager.shutdown() + await asyncio.gather(queue_task, webserver_task, return_exceptions=True) -# async def run_async_queues(manager, app, port, host): -# server = None -# try: -# await manager.run() -# server = await asyncio.start_server(app, host, port) -# await server.serve_forever() -# except Exception as e: -# logger.error(f"An error occurred while running async queues: {e}") -# finally: -# if server: -# server.close() -# await server.wait_closed() -# await manager.shutdown() + +async def graceful_shutdown(queue_task: asyncio.Task, webserver_task: asyncio.Task, manager: AsyncQueueManager): + """Ensure the graceful shutdown of tasks. + + Args: + queue_task (asyncio.Task): Task instance of manager.run() + webserver_task (asyncio.Task): Task instance of webserver + manager (AsyncQueueManager): Queue manager object + """ + try: + if queue_task: + queue_task.cancel() + if webserver_task: + webserver_task.cancel() + + await manager.shutdown() + + await asyncio.gather(queue_task, webserver_task, return_exceptions=True) + logger.info("Shutdown complete.") + except Exception as e: + logger.error(f"Error during graceful shutdown: {e}", exc_info=True) def start_standard_queue_consumer( diff --git a/pyinfra/queue/async_manager.py b/pyinfra/queue/async_manager.py index c3e6df0..802d35a 100644 --- a/pyinfra/queue/async_manager.py +++ b/pyinfra/queue/async_manager.py @@ -295,17 +295,29 @@ class AsyncQueueManager: async def shutdown(self) -> None: logger.info("Shutting down RabbitMQ handler...") - if self.channel: - # Cancel queues to stop fetching messages - logger.debug("Cancelling queues...") - for tenant, queue in self.tenant_queues.items(): - await queue.cancel(self.consumer_tags[tenant]) - await self.tenant_exchange_queue.cancel(self.consumer_tags["tenant_exchange_queue"]) - while self.message_count != 0: - logger.debug(f"Messages are still being processed: {self.message_count=} ") - await asyncio.sleep(2) - await self.channel.close() - if self.connection: - await self.connection.close() + try: + if self.channel: + # Cancel queues to stop fetching messages + logger.debug("Cancelling queues...") + for tenant, queue in self.tenant_queues.items(): + await queue.cancel(self.consumer_tags[tenant]) + await self.tenant_exchange_queue.cancel(self.consumer_tags["tenant_exchange_queue"]) + while self.message_count != 0: + logger.debug(f"Messages are still being processed: {self.message_count=} ") + await asyncio.sleep(2) + await self.channel.close() + logger.debug("Channel closed.") + else: + logger.debug("No channel to close.") + except Exception as e: + logger.error(f"Error during shutdown: {e}") + try: + if self.connection: + await self.connection.close() + else: + logger.debug("No connection to close.") + except Exception as e: + logger.error(f"Error closing connection: {e}") + logger.info("RabbitMQ handler shut down successfully.") - sys.exit(0) + # sys.exit(0) diff --git a/pyinfra/webserver/utils.py b/pyinfra/webserver/utils.py index b87f380..478ffe7 100644 --- a/pyinfra/webserver/utils.py +++ b/pyinfra/webserver/utils.py @@ -65,6 +65,10 @@ async def run_async_webserver(app: FastAPI, port: int, host: str): await server.serve() except asyncio.CancelledError: logger.info("Webserver was cancelled.") + server.should_exit = True + await server.shutdown() + except Exception as e: + logger.error(f"Error while running the webserver: {e}", exc_info=True) finally: logger.info("Webserver has been shut down.")