Merge branch 'refactor/RES-780-graceful-shutdown' into 'master'
refactor: graceful shutdown See merge request knecon/research/pyinfra!88
This commit is contained in:
commit
fdde56991b
@ -1,6 +1,7 @@
|
||||
import asyncio
|
||||
import json
|
||||
import signal
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Callable, Dict, Set
|
||||
|
||||
@ -11,6 +12,7 @@ from aio_pika.abc import (
|
||||
AbstractConnection,
|
||||
AbstractExchange,
|
||||
AbstractIncomingMessage,
|
||||
AbstractQueue,
|
||||
)
|
||||
from kn_utils.logging import logger
|
||||
from tenacity import (
|
||||
@ -65,7 +67,11 @@ class AsyncQueueManager:
|
||||
self.tenant_exchange: AbstractExchange | None = None
|
||||
self.input_exchange: AbstractExchange | None = None
|
||||
self.output_exchange: AbstractExchange | None = None
|
||||
self.tenant_exchange_queue: AbstractQueue | None = None
|
||||
self.tenant_queues: Dict[str, AbstractChannel] = {}
|
||||
self.consumer_tags: Dict[str, str] = {}
|
||||
|
||||
self.message_count: int = 0
|
||||
|
||||
async def connect(self) -> None:
|
||||
self.connection = await connect_robust(**self.config.connection_params)
|
||||
@ -88,7 +94,7 @@ class AsyncQueueManager:
|
||||
)
|
||||
|
||||
async def setup_tenant_queue(self) -> None:
|
||||
queue = await self.channel.declare_queue(
|
||||
self.tenant_exchange_queue = await self.channel.declare_queue(
|
||||
f"{self.config.pod_name}_{self.config.tenant_event_queue_suffix}",
|
||||
durable=True,
|
||||
arguments={
|
||||
@ -98,8 +104,10 @@ class AsyncQueueManager:
|
||||
"x-max-priority": 2,
|
||||
},
|
||||
)
|
||||
await queue.bind(self.tenant_exchange, routing_key="tenant.*")
|
||||
await queue.consume(self.process_tenant_message)
|
||||
await self.tenant_exchange_queue.bind(self.tenant_exchange, routing_key="tenant.*")
|
||||
self.consumer_tags["tenant_exchange_queue"] = await self.tenant_exchange_queue.consume(
|
||||
self.process_tenant_message
|
||||
)
|
||||
|
||||
async def process_tenant_message(self, message: AbstractIncomingMessage) -> None:
|
||||
async with message.process():
|
||||
@ -127,8 +135,7 @@ class AsyncQueueManager:
|
||||
},
|
||||
)
|
||||
await input_queue.bind(self.input_exchange, routing_key=tenant_id)
|
||||
await input_queue.consume(self.process_input_message)
|
||||
|
||||
self.consumer_tags[tenant_id] = await input_queue.consume(self.process_input_message)
|
||||
self.tenant_queues[tenant_id] = input_queue
|
||||
logger.info(f"Created queues for tenant {tenant_id}")
|
||||
|
||||
@ -137,6 +144,7 @@ class AsyncQueueManager:
|
||||
# somehow queue.delete() does not work here
|
||||
await self.channel.queue_delete(f"{self.config.input_queue_prefix}_{tenant_id}")
|
||||
del self.tenant_queues[tenant_id]
|
||||
del self.consumer_tags[tenant_id]
|
||||
logger.info(f"Deleted queues for tenant {tenant_id}")
|
||||
|
||||
async def process_input_message(self, message: IncomingMessage) -> None:
|
||||
@ -156,6 +164,8 @@ class AsyncQueueManager:
|
||||
await self.shutdown()
|
||||
return
|
||||
|
||||
self.message_count += 1
|
||||
|
||||
try:
|
||||
tenant_id = message.routing_key
|
||||
|
||||
@ -187,6 +197,8 @@ class AsyncQueueManager:
|
||||
await message.nack(requeue=False)
|
||||
logger.error(f"Error processing input message: {e}", exc_info=True)
|
||||
raise
|
||||
finally:
|
||||
self.message_count -= 1
|
||||
|
||||
async def publish_to_output_exchange(self, tenant_id: str, result: Dict[str, Any], headers: Dict[str, Any]) -> None:
|
||||
await self.output_exchange.publish(
|
||||
@ -252,6 +264,14 @@ class AsyncQueueManager:
|
||||
async def shutdown(self) -> None:
|
||||
logger.info("Shutting down RabbitMQ handler...")
|
||||
if self.channel:
|
||||
# Cancel queues to stop fetching messages
|
||||
logger.debug("Cancelling queues...")
|
||||
for tenant, queue in self.tenant_queues.items():
|
||||
await queue.cancel(self.consumer_tags[tenant])
|
||||
await self.tenant_exchange_queue.cancel(self.consumer_tags["tenant_exchange_queue"])
|
||||
while self.message_count != 0:
|
||||
logger.debug(f"Messages are still being processed: {self.message_count=} ")
|
||||
time.sleep(2)
|
||||
await self.channel.close()
|
||||
if self.connection:
|
||||
await self.connection.close()
|
||||
|
||||
@ -35,6 +35,8 @@ class QueueManager:
|
||||
self.connection: Union[BlockingConnection, None] = None
|
||||
self.channel: Union[BlockingChannel, None] = None
|
||||
self.connection_sleep = settings.rabbitmq.connection_sleep
|
||||
self.processing_callback = False
|
||||
self.received_signal = False
|
||||
|
||||
atexit.register(self.stop_consuming)
|
||||
signal.signal(signal.SIGTERM, self._handle_stop_signal)
|
||||
@ -151,6 +153,7 @@ class QueueManager:
|
||||
|
||||
def on_message_callback(channel, method, properties, body):
|
||||
logger.info(f"Received message from queue with delivery_tag {method.delivery_tag}.")
|
||||
self.processing_callback = True
|
||||
|
||||
if method.redelivered:
|
||||
logger.warning(f"Declining message with {method.delivery_tag=} due to it being redelivered.")
|
||||
@ -192,9 +195,17 @@ class QueueManager:
|
||||
channel.basic_nack(method.delivery_tag, requeue=False)
|
||||
raise
|
||||
|
||||
finally:
|
||||
self.processing_callback = False
|
||||
if self.received_signal:
|
||||
self.stop_consuming()
|
||||
sys.exit(0)
|
||||
|
||||
return on_message_callback
|
||||
|
||||
def _handle_stop_signal(self, signum, *args, **kwargs):
|
||||
logger.info(f"Received signal {signum}, stopping consuming...")
|
||||
self.stop_consuming()
|
||||
sys.exit(0)
|
||||
self.received_signal = True
|
||||
if not self.processing_callback:
|
||||
self.stop_consuming()
|
||||
sys.exit(0)
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
[tool.poetry]
|
||||
name = "pyinfra"
|
||||
version = "3.0.0"
|
||||
version = "3.1.0"
|
||||
description = ""
|
||||
authors = ["Team Research <research@knecon.com>"]
|
||||
license = "All rights reseverd"
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user