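"""Asynchronous RabbitMQ queue manager.

Declares the tenant, request, and response exchanges, maintains one input queue per active
tenant, reacts to tenant created/deleted events, runs the configured ``message_processor``
in a worker thread for each incoming message, and publishes results to the response exchange.
"""
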
import asyncio
import concurrent.futures
import json
from dataclasses import dataclass, field
from typing import Any, Callable, Dict, Set

import aiohttp
from aio_pika import ExchangeType, IncomingMessage, Message, connect
from aio_pika.abc import (
    AbstractChannel,
    AbstractConnection,
    AbstractExchange,
    AbstractIncomingMessage,
    AbstractQueue,
)
from aio_pika.exceptions import (
    ChannelClosed,
    ChannelInvalidStateError,
    ConnectionClosed,
)
from aiormq.exceptions import AMQPConnectionError
from kn_utils.logging import logger
from kn_utils.retry import retry


@dataclass
class RabbitMQConfig:
    host: str
    port: int
    username: str
    password: str
    heartbeat: int
    input_queue_prefix: str
    tenant_event_queue_suffix: str
    tenant_exchange_name: str
    service_request_exchange_name: str
    service_response_exchange_name: str
    service_dead_letter_queue_name: str
    queue_expiration_time: int
    pod_name: str

    connection_params: Dict[str, object] = field(init=False)

    def __post_init__(self):
        self.connection_params = {
            "host": self.host,
            "port": self.port,
            "login": self.username,
            "password": self.password,
            "client_properties": {"heartbeat": self.heartbeat},
        }


class AsyncQueueManager:
    def __init__(
        self,
        config: RabbitMQConfig,
        tenant_service_url: str,
        message_processor: Callable[[Dict[str, Any]], Dict[str, Any]],
        max_concurrent_tasks: int = 10,
    ):
        self.config = config
        self.tenant_service_url = tenant_service_url
        self.message_processor = message_processor
        self.semaphore = asyncio.Semaphore(max_concurrent_tasks)

        self.connection: AbstractConnection | None = None
        self.channel: AbstractChannel | None = None
        self.tenant_exchange: AbstractExchange | None = None
        self.input_exchange: AbstractExchange | None = None
        self.output_exchange: AbstractExchange | None = None
        self.tenant_exchange_queue: AbstractQueue | None = None
        self.dead_letter_queue: AbstractQueue | None = None
        self.tenant_queues: Dict[str, AbstractQueue] = {}
        self.consumer_tags: Dict[str, str] = {}

        self.message_count: int = 0

    @retry(tries=5, exceptions=AMQPConnectionError, reraise=True, logger=logger)
    async def connect(self) -> None:
        logger.info("Attempting to connect to RabbitMQ...")
        self.connection = await connect(**self.config.connection_params)
        self.connection.close_callbacks.add(self.on_connection_close)
        self.channel = await self.connection.channel()
        await self.channel.set_qos(prefetch_count=1)
        logger.info("Successfully connected to RabbitMQ")

    async def on_connection_close(self, sender, exc):
        """Callback invoked when the connection is closed unexpectedly."""
        logger.debug(f"Sender: {sender}")
        if isinstance(exc, ConnectionClosed):
            logger.warning("Connection to RabbitMQ lost. Attempting to reconnect...")
            try:
                active_tenants = await self.fetch_active_tenants()
                await self.run(active_tenants=active_tenants)
                logger.debug("Reconnected to RabbitMQ successfully")
            except Exception as e:
                logger.warning(f"Failed to reconnect to RabbitMQ: {e}")
                # Cancel the queue manager and webserver tasks to shut down the service.
                tasks = [t for t in asyncio.all_tasks() if t is not asyncio.current_task()]
                for task in tasks:
                    if task.get_name() in ["queues", "webserver"]:
                        task.cancel()
        else:
            logger.debug("Connection closed on purpose.")

    async def is_ready(self) -> bool:
        if self.connection is None or self.connection.is_closed:
            try:
                await self.connect()
            except Exception as e:
                logger.error(f"Failed to connect to RabbitMQ: {e}")
                return False
        return True

    @retry(tries=5, exceptions=(AMQPConnectionError, ChannelInvalidStateError), reraise=True, logger=logger)
    async def setup_exchanges(self) -> None:
        self.tenant_exchange = await self.channel.declare_exchange(
            self.config.tenant_exchange_name, ExchangeType.TOPIC, durable=True
        )
        self.input_exchange = await self.channel.declare_exchange(
            self.config.service_request_exchange_name, ExchangeType.DIRECT, durable=True
        )
        self.output_exchange = await self.channel.declare_exchange(
            self.config.service_response_exchange_name, ExchangeType.DIRECT, durable=True
        )

        # Declare the dead-letter queue so rejected messages have somewhere to go.
        self.dead_letter_queue = await self.channel.declare_queue(
            self.config.service_dead_letter_queue_name, durable=True
        )

    @retry(tries=5, exceptions=(AMQPConnectionError, ChannelInvalidStateError), reraise=True, logger=logger)
    async def setup_tenant_queue(self) -> None:
        self.tenant_exchange_queue = await self.channel.declare_queue(
            f"{self.config.pod_name}_{self.config.tenant_event_queue_suffix}",
            durable=True,
            arguments={
                "x-dead-letter-exchange": "",
                "x-dead-letter-routing-key": self.config.service_dead_letter_queue_name,
                "x-expires": self.config.queue_expiration_time,
            },
        )
        await self.tenant_exchange_queue.bind(self.tenant_exchange, routing_key="tenant.*")
        self.consumer_tags["tenant_exchange_queue"] = await self.tenant_exchange_queue.consume(
            self.process_tenant_message
        )

    async def process_tenant_message(self, message: AbstractIncomingMessage) -> None:
        try:
            async with message.process():
                message_body = json.loads(message.body.decode())
                logger.debug(f"Tenant message received: {message_body}")
                tenant_id = message_body["tenantId"]
                routing_key = message.routing_key

                if routing_key == "tenant.created":
                    await self.create_tenant_queues(tenant_id)
                elif routing_key == "tenant.delete":
                    await self.delete_tenant_queues(tenant_id)
        except Exception as e:
            logger.error(e, exc_info=True)

    async def create_tenant_queues(self, tenant_id: str) -> None:
        queue_name = f"{self.config.input_queue_prefix}_{tenant_id}"
        logger.info(f"Declaring queue: {queue_name}")
        try:
            input_queue = await self.channel.declare_queue(
                queue_name,
                durable=True,
                arguments={
                    "x-dead-letter-exchange": "",
                    "x-dead-letter-routing-key": self.config.service_dead_letter_queue_name,
                },
            )
            await input_queue.bind(self.input_exchange, routing_key=tenant_id)
            self.consumer_tags[tenant_id] = await input_queue.consume(self.process_input_message)
            self.tenant_queues[tenant_id] = input_queue
            logger.info(f"Created and started consuming queue for tenant {tenant_id}")
        except Exception as e:
            logger.error(e, exc_info=True)

    async def delete_tenant_queues(self, tenant_id: str) -> None:
        if tenant_id in self.tenant_queues:
            # queue.delete() does not work here, so delete the queue by name via the channel.
            await self.channel.queue_delete(f"{self.config.input_queue_prefix}_{tenant_id}")
            del self.tenant_queues[tenant_id]
            del self.consumer_tags[tenant_id]
            logger.info(f"Deleted queues for tenant {tenant_id}")

    async def process_input_message(self, message: IncomingMessage) -> None:
        async def process_message_body_and_await_result(unpacked_message_body):
            async with self.semaphore:
                loop = asyncio.get_running_loop()
                with concurrent.futures.ThreadPoolExecutor(max_workers=1) as thread_pool_executor:
                    logger.info("Processing payload in a separate thread.")
                    result = await loop.run_in_executor(
                        thread_pool_executor, self.message_processor, unpacked_message_body
                    )
                    return result

        async with message.process(ignore_processed=True):
            if message.redelivered:
                logger.warning(f"Declining message with {message.delivery_tag=} due to it being redelivered.")
                await message.nack(requeue=False)
                return

            if message.body.decode("utf-8") == "STOP":
                logger.info("Received stop signal, stopping consumption...")
                await message.ack()
                # TODO: shutdown is probably not the right call here - align w/ Dev what should happen on stop signal
                await self.shutdown()
                return

            self.message_count += 1

            try:
                tenant_id = message.routing_key

                filtered_message_headers = (
                    {k: v for k, v in message.headers.items() if k.lower().startswith("x-")} if message.headers else {}
                )

                logger.debug(f"Processing message with {filtered_message_headers=}.")

                # Apply the empty-dict fallback to the awaited result, not to the coroutine object
                # (a coroutine is always truthy, so the previous `await (coro or {})` never fell back).
                result: dict = (
                    await process_message_body_and_await_result({**json.loads(message.body), **filtered_message_headers})
                    or {}
                )

                if result:
                    await self.publish_to_output_exchange(tenant_id, result, filtered_message_headers)
                    await message.ack()
                    logger.debug(f"Message with {message.delivery_tag=} acknowledged.")
                else:
                    raise ValueError(f"Could not process message with {message.body=}.")

            except json.JSONDecodeError:
                await message.nack(requeue=False)
                logger.error(f"Invalid JSON in input message: {message.body}", exc_info=True)
            except FileNotFoundError as e:
                logger.warning(f"{e}, declining message with {message.delivery_tag=}.", exc_info=True)
                await message.nack(requeue=False)
            except Exception as e:
                await message.nack(requeue=False)
                logger.error(f"Error processing input message: {e}", exc_info=True)
            finally:
                self.message_count -= 1

    async def publish_to_output_exchange(self, tenant_id: str, result: Dict[str, Any], headers: Dict[str, Any]) -> None:
        await self.output_exchange.publish(
            Message(body=json.dumps(result).encode(), headers=headers),
            routing_key=tenant_id,
        )
        logger.info(f"Published result to queue {tenant_id}.")

    @retry(tries=5, exceptions=(aiohttp.ClientResponseError, aiohttp.ClientConnectorError), reraise=True, logger=logger)
    async def fetch_active_tenants(self) -> Set[str]:
        async with aiohttp.ClientSession() as session:
            async with session.get(self.tenant_service_url) as response:
                response.raise_for_status()
                # content_type is the parsed MIME type, so parameters such as "; charset=utf-8" are ignored.
                if response.content_type == "application/json":
                    data = await response.json()
                    return {tenant["tenantId"] for tenant in data}
                else:
                    logger.error(
                        f"Failed to fetch active tenants. Content type is not JSON: {response.content_type}"
                    )
                    return set()

    @retry(
        tries=5,
        exceptions=(
            AMQPConnectionError,
            ChannelInvalidStateError,
        ),
        reraise=True,
        logger=logger,
    )
    async def initialize_tenant_queues(self, active_tenants: set) -> None:
        for tenant_id in active_tenants:
            await self.create_tenant_queues(tenant_id)

    async def run(self, active_tenants: set) -> None:
        await self.connect()
        await self.setup_exchanges()
        await self.initialize_tenant_queues(active_tenants=active_tenants)
        await self.setup_tenant_queue()

        logger.info("RabbitMQ handler is running. Press CTRL+C to exit.")

    async def close_channels(self) -> None:
        try:
            if self.channel and not self.channel.is_closed:
                # Cancel consumers so the queues stop fetching messages.
                logger.debug("Cancelling queues...")
                for tenant, queue in self.tenant_queues.items():
                    await queue.cancel(self.consumer_tags[tenant])
                if self.tenant_exchange_queue:
                    await self.tenant_exchange_queue.cancel(self.consumer_tags["tenant_exchange_queue"])
                while self.message_count != 0:
                    logger.debug(f"Messages are still being processed: {self.message_count=}")
                    await asyncio.sleep(2)
                await self.channel.close(exc=asyncio.CancelledError)
                logger.debug("Channel closed.")
            else:
                logger.debug("No channel to close.")
        except ChannelClosed:
            logger.warning("Channel was already closed.")
        except ConnectionClosed:
            logger.warning("Connection was lost, unable to close channel.")
        except Exception as e:
            logger.error(f"Error during channel shutdown: {e}")

    async def close_connection(self) -> None:
        try:
            if self.connection and not self.connection.is_closed:
                await self.connection.close(exc=asyncio.CancelledError)
                logger.debug("Connection closed.")
            else:
                logger.debug("No connection to close.")
        except ConnectionClosed:
            logger.warning("Connection was already closed.")
        except Exception as e:
            logger.error(f"Error closing connection: {e}")

    async def shutdown(self) -> None:
        logger.info("Shutting down RabbitMQ handler...")
        await self.close_channels()
        await self.close_connection()
        logger.info("RabbitMQ handler shut down successfully.")
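

# The block below is an illustrative usage sketch, not part of the original module: it wires a
# RabbitMQConfig and AsyncQueueManager together with placeholder values (host, credentials, URLs,
# queue and exchange names are all assumptions) and a trivial echo message_processor, then keeps
# the event loop alive so the consumers keep receiving messages.
if __name__ == "__main__":

    def echo_processor(payload: Dict[str, Any]) -> Dict[str, Any]:
        # Hypothetical processor: runs in a worker thread and must return a JSON-serialisable dict.
        return {"echo": payload}

    example_config = RabbitMQConfig(
        host="localhost",  # placeholder broker
        port=5672,
        username="guest",  # placeholder credentials
        password="guest",
        heartbeat=60,
        input_queue_prefix="example_service_input",
        tenant_event_queue_suffix="tenant_events",
        tenant_exchange_name="tenant_exchange",
        service_request_exchange_name="example_service_requests",
        service_response_exchange_name="example_service_responses",
        service_dead_letter_queue_name="example_service_dlq",
        queue_expiration_time=3_600_000,  # milliseconds, passed through as the x-expires argument
        pod_name="example-pod-0",
    )

    async def _main() -> None:
        manager = AsyncQueueManager(
            config=example_config,
            tenant_service_url="http://tenant-service/tenants",  # placeholder URL
            message_processor=echo_processor,
        )
        active_tenants = await manager.fetch_active_tenants()
        try:
            # run() declares exchanges and queues and returns; consumption happens via callbacks.
            await manager.run(active_tenants=active_tenants)
            await asyncio.Event().wait()  # block forever until cancelled (e.g. CTRL+C)
        finally:
            await manager.shutdown()

    asyncio.run(_main())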