From 47b42e95e2758a4f0be58b32e0ec0660aaf73309 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jonathan=20K=C3=B6ssler?= Date: Thu, 1 Aug 2024 15:31:58 +0200 Subject: [PATCH 1/2] refactor: graceful shutdown --- pyinfra/queue/async_manager.py | 28 +++++++++++++++++++++++----- pyinfra/queue/manager.py | 15 +++++++++++++-- 2 files changed, 36 insertions(+), 7 deletions(-) diff --git a/pyinfra/queue/async_manager.py b/pyinfra/queue/async_manager.py index c435ef3..17a50eb 100644 --- a/pyinfra/queue/async_manager.py +++ b/pyinfra/queue/async_manager.py @@ -1,6 +1,7 @@ import asyncio import json import signal +import time from dataclasses import dataclass, field from typing import Any, Callable, Dict, Set @@ -11,6 +12,7 @@ from aio_pika.abc import ( AbstractConnection, AbstractExchange, AbstractIncomingMessage, + AbstractQueue, ) from kn_utils.logging import logger from tenacity import ( @@ -65,7 +67,11 @@ class AsyncQueueManager: self.tenant_exchange: AbstractExchange | None = None self.input_exchange: AbstractExchange | None = None self.output_exchange: AbstractExchange | None = None + self.tenant_exchange_queue: AbstractQueue | None = None self.tenant_queues: Dict[str, AbstractChannel] = {} + self.consumer_tags: Dict[str, str] = {} + + self.message_count: int = 0 async def connect(self) -> None: self.connection = await connect_robust(**self.config.connection_params) @@ -88,7 +94,7 @@ class AsyncQueueManager: ) async def setup_tenant_queue(self) -> None: - queue = await self.channel.declare_queue( + self.tenant_exchange_queue = await self.channel.declare_queue( f"{self.config.pod_name}_{self.config.tenant_event_queue_suffix}", durable=True, arguments={ @@ -98,8 +104,10 @@ class AsyncQueueManager: "x-max-priority": 2, }, ) - await queue.bind(self.tenant_exchange, routing_key="tenant.*") - await queue.consume(self.process_tenant_message) + await self.tenant_exchange_queue.bind(self.tenant_exchange, routing_key="tenant.*") + self.consumer_tags["tenant_exchange_queue"] = await self.tenant_exchange_queue.consume( + self.process_tenant_message + ) async def process_tenant_message(self, message: AbstractIncomingMessage) -> None: async with message.process(): @@ -127,8 +135,7 @@ class AsyncQueueManager: }, ) await input_queue.bind(self.input_exchange, routing_key=tenant_id) - await input_queue.consume(self.process_input_message) - + self.consumer_tags[tenant_id] = await input_queue.consume(self.process_input_message) self.tenant_queues[tenant_id] = input_queue logger.info(f"Created queues for tenant {tenant_id}") @@ -137,6 +144,7 @@ class AsyncQueueManager: # somehow queue.delete() does not work here await self.channel.queue_delete(f"{self.config.input_queue_prefix}_{tenant_id}") del self.tenant_queues[tenant_id] + del self.consumer_tags[tenant_id] logger.info(f"Deleted queues for tenant {tenant_id}") async def process_input_message(self, message: IncomingMessage) -> None: @@ -157,6 +165,7 @@ class AsyncQueueManager: return try: + self.message_count += 1 tenant_id = message.routing_key filtered_message_headers = ( @@ -174,6 +183,7 @@ class AsyncQueueManager: await self.publish_to_output_exchange(tenant_id, result, filtered_message_headers) await message.ack() logger.debug(f"Message with {message.delivery_tag=} acknowledged.") + self.message_count -= 1 else: raise ValueError(f"Could not process message with {message.body=}.") @@ -252,6 +262,14 @@ class AsyncQueueManager: async def shutdown(self) -> None: logger.info("Shutting down RabbitMQ handler...") if self.channel: + # Cancel queues to stop fetching messages + logger.debug("Cancelling queues...") + for tenant, queue in self.tenant_queues.items(): + await queue.cancel(self.consumer_tags[tenant]) + await self.tenant_exchange_queue.cancel(self.consumer_tags["tenant_exchange_queue"]) + while self.message_count != 0: + logger.debug(f"Messages are still being processed: {self.message_count=} ") + time.sleep(2) await self.channel.close() if self.connection: await self.connection.close() diff --git a/pyinfra/queue/manager.py b/pyinfra/queue/manager.py index bb30d1d..b1d1a77 100644 --- a/pyinfra/queue/manager.py +++ b/pyinfra/queue/manager.py @@ -35,6 +35,8 @@ class QueueManager: self.connection: Union[BlockingConnection, None] = None self.channel: Union[BlockingChannel, None] = None self.connection_sleep = settings.rabbitmq.connection_sleep + self.processing_callback = False + self.received_signal = False atexit.register(self.stop_consuming) signal.signal(signal.SIGTERM, self._handle_stop_signal) @@ -151,6 +153,7 @@ class QueueManager: def on_message_callback(channel, method, properties, body): logger.info(f"Received message from queue with delivery_tag {method.delivery_tag}.") + self.processing_callback = True if method.redelivered: logger.warning(f"Declining message with {method.delivery_tag=} due to it being redelivered.") @@ -192,9 +195,17 @@ class QueueManager: channel.basic_nack(method.delivery_tag, requeue=False) raise + finally: + self.processing_callback = False + if self.received_signal: + self.stop_consuming() + sys.exit(0) + return on_message_callback def _handle_stop_signal(self, signum, *args, **kwargs): logger.info(f"Received signal {signum}, stopping consuming...") - self.stop_consuming() - sys.exit(0) + self.received_signal = True + if not self.processing_callback: + self.stop_consuming() + sys.exit(0) From cb8509b1206ef68d419d500c2dcbbd1bb4d0d6d6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jonathan=20K=C3=B6ssler?= Date: Thu, 1 Aug 2024 17:42:59 +0200 Subject: [PATCH 2/2] refactor: message counter --- pyinfra/queue/async_manager.py | 6 ++++-- pyproject.toml | 2 +- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/pyinfra/queue/async_manager.py b/pyinfra/queue/async_manager.py index 17a50eb..75efd8e 100644 --- a/pyinfra/queue/async_manager.py +++ b/pyinfra/queue/async_manager.py @@ -164,8 +164,9 @@ class AsyncQueueManager: await self.shutdown() return + self.message_count += 1 + try: - self.message_count += 1 tenant_id = message.routing_key filtered_message_headers = ( @@ -183,7 +184,6 @@ class AsyncQueueManager: await self.publish_to_output_exchange(tenant_id, result, filtered_message_headers) await message.ack() logger.debug(f"Message with {message.delivery_tag=} acknowledged.") - self.message_count -= 1 else: raise ValueError(f"Could not process message with {message.body=}.") @@ -197,6 +197,8 @@ class AsyncQueueManager: await message.nack(requeue=False) logger.error(f"Error processing input message: {e}", exc_info=True) raise + finally: + self.message_count -= 1 async def publish_to_output_exchange(self, tenant_id: str, result: Dict[str, Any], headers: Dict[str, Any]) -> None: await self.output_exchange.publish( diff --git a/pyproject.toml b/pyproject.toml index 13504bd..2fc214f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "pyinfra" -version = "3.0.0" +version = "3.1.0" description = "" authors = ["Team Research "] license = "All rights reseverd"