Merge branch 'refactor/RES-780-graceful-shutdown' into 'master'

refactor: graceful shutdown

See merge request knecon/research/pyinfra!88
This commit is contained in:
Jonathan Kössler 2024-08-02 13:57:04 +02:00
commit fdde56991b
3 changed files with 39 additions and 8 deletions

View File

@ -1,6 +1,7 @@
import asyncio
import json
import signal
import time
from dataclasses import dataclass, field
from typing import Any, Callable, Dict, Set
@ -11,6 +12,7 @@ from aio_pika.abc import (
AbstractConnection,
AbstractExchange,
AbstractIncomingMessage,
AbstractQueue,
)
from kn_utils.logging import logger
from tenacity import (
@ -65,7 +67,11 @@ class AsyncQueueManager:
self.tenant_exchange: AbstractExchange | None = None
self.input_exchange: AbstractExchange | None = None
self.output_exchange: AbstractExchange | None = None
self.tenant_exchange_queue: AbstractQueue | None = None
self.tenant_queues: Dict[str, AbstractChannel] = {}
self.consumer_tags: Dict[str, str] = {}
self.message_count: int = 0
async def connect(self) -> None:
self.connection = await connect_robust(**self.config.connection_params)
@ -88,7 +94,7 @@ class AsyncQueueManager:
)
async def setup_tenant_queue(self) -> None:
queue = await self.channel.declare_queue(
self.tenant_exchange_queue = await self.channel.declare_queue(
f"{self.config.pod_name}_{self.config.tenant_event_queue_suffix}",
durable=True,
arguments={
@ -98,8 +104,10 @@ class AsyncQueueManager:
"x-max-priority": 2,
},
)
await queue.bind(self.tenant_exchange, routing_key="tenant.*")
await queue.consume(self.process_tenant_message)
await self.tenant_exchange_queue.bind(self.tenant_exchange, routing_key="tenant.*")
self.consumer_tags["tenant_exchange_queue"] = await self.tenant_exchange_queue.consume(
self.process_tenant_message
)
async def process_tenant_message(self, message: AbstractIncomingMessage) -> None:
async with message.process():
@ -127,8 +135,7 @@ class AsyncQueueManager:
},
)
await input_queue.bind(self.input_exchange, routing_key=tenant_id)
await input_queue.consume(self.process_input_message)
self.consumer_tags[tenant_id] = await input_queue.consume(self.process_input_message)
self.tenant_queues[tenant_id] = input_queue
logger.info(f"Created queues for tenant {tenant_id}")
@ -137,6 +144,7 @@ class AsyncQueueManager:
# somehow queue.delete() does not work here
await self.channel.queue_delete(f"{self.config.input_queue_prefix}_{tenant_id}")
del self.tenant_queues[tenant_id]
del self.consumer_tags[tenant_id]
logger.info(f"Deleted queues for tenant {tenant_id}")
async def process_input_message(self, message: IncomingMessage) -> None:
@ -156,6 +164,8 @@ class AsyncQueueManager:
await self.shutdown()
return
self.message_count += 1
try:
tenant_id = message.routing_key
@ -187,6 +197,8 @@ class AsyncQueueManager:
await message.nack(requeue=False)
logger.error(f"Error processing input message: {e}", exc_info=True)
raise
finally:
self.message_count -= 1
async def publish_to_output_exchange(self, tenant_id: str, result: Dict[str, Any], headers: Dict[str, Any]) -> None:
await self.output_exchange.publish(
@ -252,6 +264,14 @@ class AsyncQueueManager:
async def shutdown(self) -> None:
    """Stop consuming, wait for in-flight messages, then close channel and connection.

    When a channel is open, the consumers registered in ``self.consumer_tags``
    are cancelled on every tenant queue and on the tenant event queue so that
    no further messages are delivered. The method then waits until
    ``self.message_count`` is zero before closing the channel. The connection
    is closed last, if one exists.
    """
    logger.info("Shutting down RabbitMQ handler...")
    if self.channel:
        # Cancel queues to stop fetching messages
        logger.debug("Cancelling queues...")
        # Iterate over a snapshot: every await yields to the event loop, and
        # another task may add or remove tenant queues in the meantime.
        for tenant, queue in list(self.tenant_queues.items()):
            consumer_tag = self.consumer_tags.get(tenant)
            if consumer_tag is not None:
                await queue.cancel(consumer_tag)
        # The tenant event queue may never have been set up; skip it then.
        tenant_consumer_tag = self.consumer_tags.get("tenant_exchange_queue")
        if self.tenant_exchange_queue is not None and tenant_consumer_tag is not None:
            await self.tenant_exchange_queue.cancel(tenant_consumer_tag)
        while self.message_count != 0:
            logger.debug(f"Messages are still being processed: {self.message_count=} ")
            # A blocking time.sleep() would freeze the event loop, so the
            # coroutines that decrement message_count could never finish and
            # this loop would spin forever. Yield to the loop instead.
            await asyncio.sleep(2)
        await self.channel.close()
    if self.connection:
        await self.connection.close()

View File

@ -35,6 +35,8 @@ class QueueManager:
self.connection: Union[BlockingConnection, None] = None
self.channel: Union[BlockingChannel, None] = None
self.connection_sleep = settings.rabbitmq.connection_sleep
self.processing_callback = False
self.received_signal = False
atexit.register(self.stop_consuming)
signal.signal(signal.SIGTERM, self._handle_stop_signal)
@ -151,6 +153,7 @@ class QueueManager:
def on_message_callback(channel, method, properties, body):
logger.info(f"Received message from queue with delivery_tag {method.delivery_tag}.")
self.processing_callback = True
if method.redelivered:
logger.warning(f"Declining message with {method.delivery_tag=} due to it being redelivered.")
@ -192,9 +195,17 @@ class QueueManager:
channel.basic_nack(method.delivery_tag, requeue=False)
raise
finally:
self.processing_callback = False
if self.received_signal:
self.stop_consuming()
sys.exit(0)
return on_message_callback
def _handle_stop_signal(self, signum, *args, **kwargs):
    """Handle a stop signal, deferring the exit while a message is being processed.

    The signal is always recorded in ``self.received_signal``. If no message
    callback is currently running (``self.processing_callback`` is false),
    consuming is stopped and the process exits immediately with status 0.
    Otherwise the handler returns and leaves the shutdown to the code that
    checks ``self.received_signal`` once the callback has finished.

    Args:
        signum: Number of the received signal (used for logging only).
        *args: Remaining positional arguments passed by the signal machinery; ignored.
        **kwargs: Ignored.
    """
    logger.info(f"Received signal {signum}, stopping consuming...")
    # Record the signal first so that a running callback can see it; exiting
    # unconditionally here would abort the message that is in flight.
    self.received_signal = True
    if not self.processing_callback:
        self.stop_consuming()
        sys.exit(0)

View File

@ -1,6 +1,6 @@
[tool.poetry]
name = "pyinfra"
version = "3.0.0"
version = "3.1.0"
description = ""
authors = ["Team Research <research@knecon.com>"]
license = "All rights reserved"