import asyncio
import gzip
import json
from dataclasses import dataclass, field
from operator import itemgetter
from typing import Any, Awaitable, Callable, Dict, Set

import aiohttp
from aio_pika import connect_robust, ExchangeType, IncomingMessage, Message
from aio_pika.abc import (
    AbstractChannel,
    AbstractConnection,
    AbstractExchange,
    AbstractIncomingMessage,
    AbstractQueue,
)

from kn_utils.logging import logger
from pyinfra.config.loader import load_settings, local_pyinfra_root_path
from pyinfra.storage.storages.s3 import get_s3_storage_from_settings

settings = load_settings(local_pyinfra_root_path / "config/")


@dataclass
class RabbitMQConfig:
    """Connection and topology settings for :class:`RabbitMQHandler`.

    ``connection_params`` is derived in ``__post_init__`` and holds the keyword
    arguments in the exact shape expected by ``aio_pika.connect_robust``.
    """

    host: str
    port: int
    username: str
    password: str
    heartbeat: int
    input_queue_prefix: str
    tenant_exchange_name: str
    service_request_exchange_name: str
    service_response_exchange_name: str
    service_dead_letter_queue_name: str
    # Idle TTL for per-tenant queues, passed to RabbitMQ as "x-expires"
    # (milliseconds, per the RabbitMQ queue-TTL extension).
    queue_expiration_time: int
    connection_params: Dict[str, object] = field(init=False)

    def __post_init__(self) -> None:
        self.connection_params = {
            "host": self.host,
            "port": self.port,
            "login": self.username,
            "password": self.password,
            "client_properties": {"heartbeat": self.heartbeat},
        }


class RabbitMQHandler:
    """Bridges RabbitMQ and an async message processor.

    Listens on a topic exchange for tenant lifecycle events
    (``tenant.create`` / ``tenant.delete``), maintains one input queue per
    tenant bound to the request exchange, runs ``message_processor`` on each
    request, and publishes results to the response exchange under the
    tenant's routing key.
    """

    def __init__(
        self,
        config: RabbitMQConfig,
        tenant_service_url: str,
        message_processor: Callable[[Dict[str, Any]], Awaitable[Dict[str, Any]]],
    ):
        self.config = config
        self.tenant_service_url = tenant_service_url
        # Awaited coroutine function: dict in -> dict out (empty/None result
        # suppresses publishing).
        self.message_processor = message_processor
        self.connection: AbstractConnection | None = None
        self.channel: AbstractChannel | None = None
        self.tenant_exchange: AbstractExchange | None = None
        self.input_exchange: AbstractExchange | None = None
        self.output_exchange: AbstractExchange | None = None
        # declare_queue() returns queue objects, so the map holds queues,
        # not channels.
        self.tenant_queues: Dict[str, AbstractQueue] = {}

    async def connect(self) -> None:
        """Open a robust (auto-reconnecting) connection and one channel."""
        self.connection = await connect_robust(**self.config.connection_params)
        self.channel = await self.connection.channel()
        # Deliver one unacked message at a time so a slow processor does not
        # hoard work from other consumers.
        await self.channel.set_qos(prefetch_count=1)

    async def setup_exchanges(self) -> None:
        """Declare the tenant-event, request, and response exchanges (durable)."""
        self.tenant_exchange = await self.channel.declare_exchange(
            self.config.tenant_exchange_name, ExchangeType.TOPIC, durable=True
        )
        self.input_exchange = await self.channel.declare_exchange(
            self.config.service_request_exchange_name, ExchangeType.DIRECT, durable=True
        )
        self.output_exchange = await self.channel.declare_exchange(
            self.config.service_response_exchange_name, ExchangeType.DIRECT, durable=True
        )

    async def setup_tenant_queue(self) -> None:
        """Declare the queue for tenant lifecycle events and start consuming it."""
        queue = await self.channel.declare_queue(
            "tenant_queue",
            durable=True,
            arguments={
                # Failed messages go to the service dead-letter queue via the
                # default exchange.
                "x-dead-letter-exchange": "",
                "x-dead-letter-routing-key": self.config.service_dead_letter_queue_name,
            },
        )
        # Matches both tenant.create and tenant.delete.
        await queue.bind(self.tenant_exchange, routing_key="tenant.*")
        await queue.consume(self.process_tenant_message)

    async def process_tenant_message(self, message: AbstractIncomingMessage) -> None:
        """Create or delete a tenant's queues based on the event routing key."""
        async with message.process():
            message_body = json.loads(message.body.decode())
            logger.debug(f"Tenant message received: {message_body}")
            tenant_id = message_body["tenantId"]
            routing_key = message.routing_key
            if routing_key == "tenant.create":
                await self.create_tenant_queues(tenant_id)
            elif routing_key == "tenant.delete":
                await self.delete_tenant_queues(tenant_id)

    async def create_tenant_queues(self, tenant_id: str) -> None:
        """Declare, bind, and start consuming the input queue for *tenant_id*."""
        queue_name = f"{self.config.input_queue_prefix}_{tenant_id}"
        logger.info(f"Declaring queue: {queue_name}")
        input_queue = await self.channel.declare_queue(
            queue_name,
            durable=True,
            arguments={
                "x-dead-letter-exchange": "",
                "x-dead-letter-routing-key": self.config.service_dead_letter_queue_name,
                # Queue is auto-deleted by the broker after this idle period (ms).
                "x-expires": self.config.queue_expiration_time,
                "x-max-priority": 2,
            },
        )
        await input_queue.bind(self.input_exchange, routing_key=tenant_id)
        await input_queue.consume(self.process_input_message)
        self.tenant_queues[tenant_id] = input_queue
        logger.info(f"Created queues for tenant {tenant_id}")

    async def delete_tenant_queues(self, tenant_id: str) -> None:
        """Delete the input queue for *tenant_id* if this handler created it."""
        if tenant_id in self.tenant_queues:
            input_queue = self.tenant_queues[tenant_id]
            await input_queue.delete()
            del self.tenant_queues[tenant_id]
            logger.info(f"Deleted queues for tenant {tenant_id}")

    async def process_input_message(self, message: AbstractIncomingMessage) -> None:
        """Run the processor on one service request and publish the result.

        Error policy: malformed JSON and FileNotFoundError are logged and the
        message is ACKed (dropped — ``message.process()`` acknowledges when the
        block exits without an exception); any other exception is re-raised so
        the message is rejected and dead-lettered.
        """
        async with message.process():
            try:
                # Per-tenant queues are bound with the tenant id as routing key.
                tenant_id = message.routing_key
                message_body = json.loads(message.body.decode("utf-8"))
                # Forward only extension headers ("x-...") into the payload and
                # onto the response.
                filtered_message_headers = (
                    {k: v for k, v in message.headers.items() if k.lower().startswith("x-")}
                    if message.headers
                    else {}
                )
                logger.debug(f"Processing message with {filtered_message_headers=}.")
                message_body.update(filtered_message_headers)
                result = await self.message_processor(message_body)
                if result:
                    await self.publish_to_output_exchange(tenant_id, result, filtered_message_headers)
            except json.JSONDecodeError:
                logger.error(f"Invalid JSON in input message: {message.body}")
            except FileNotFoundError as e:
                logger.warning(f"{e}, declining message.")
            except Exception as e:
                logger.error(f"Error processing input message: {e}", exc_info=True)
                raise

    async def publish_to_output_exchange(
        self, tenant_id: str, result: Dict[str, Any], headers: Dict[str, Any]
    ) -> None:
        """Publish *result* to the response exchange under the tenant's routing key."""
        await self.output_exchange.publish(
            Message(body=json.dumps(result).encode(), headers=headers),
            routing_key=tenant_id,
        )
        logger.info(f"Published result to queue {tenant_id}.")

    async def fetch_active_tenants(self) -> Set[str]:
        """Return the active tenant ids from the tenant service; empty set on failure."""
        try:
            async with aiohttp.ClientSession() as session:
                async with session.get(self.tenant_service_url) as response:
                    if response.status == 200:
                        data = await response.json()
                        return {tenant["tenantId"] for tenant in data}
                    logger.error(f"Failed to fetch active tenants. Status: {response.status}")
                    return set()
        except aiohttp.ClientError as e:
            logger.error(f"Error fetching active tenants: {e}")
            return set()

    async def initialize_tenant_queues(self) -> None:
        """Pre-create queues for every tenant the tenant service reports as active."""
        active_tenants = await self.fetch_active_tenants()
        for tenant_id in active_tenants:
            await self.create_tenant_queues(tenant_id)

    async def run(self) -> None:
        """Connect, set up topology, then serve until cancelled; closes on exit."""
        try:
            await self.connect()
            await self.setup_exchanges()
            await self.initialize_tenant_queues()
            await self.setup_tenant_queue()
            logger.info("RabbitMQ handler is running. Press CTRL+C to exit.")
            await asyncio.Future()  # Run forever
        except asyncio.CancelledError:
            logger.info("Shutting down RabbitMQ handler...")
        except Exception as e:
            logger.error(f"An error occurred: {e}", exc_info=True)
        finally:
            if self.connection:
                await self.connection.close()


async def dummy_message_processor(message: Dict[str, Any]) -> Dict[str, Any]:
    """Demo processor: reads a gzipped JSON object from S3, writes a processed copy.

    Expects ``tenantId``/``dossierId``/``fileId`` plus ``targetFileExtension``
    (input object suffix) and ``responseFileExtension`` (output object suffix)
    in *message*.
    """
    logger.info(f"Processing message: {message}")
    await asyncio.sleep(1)  # Simulate processing time
    storage = get_s3_storage_from_settings(settings)
    tenant_id, dossier_id, file_id = itemgetter("tenantId", "dossierId", "fileId")(message)
    suffix = message["responseFileExtension"]
    object_name = f"{tenant_id}/{dossier_id}/{file_id}.{message['targetFileExtension']}"
    original_content = json.loads(gzip.decompress(storage.get_object(object_name)))
    processed_content = {
        "processedPages": original_content["numberOfPages"],
        "processedSectionTexts": f"Processed: {original_content['sectionTexts']}",
    }
    processed_object_name = f"{tenant_id}/{dossier_id}/{file_id}.{suffix}"
    processed_data = gzip.compress(json.dumps(processed_content).encode("utf-8"))
    storage.put_object(processed_object_name, processed_data)
    processed_message = message.copy()
    processed_message["processed"] = True
    processed_message["processor_message"] = "This message was processed by the dummy processor"
    logger.info(f"Finished processing message. Result: {processed_message}")
    return processed_message


async def test_rabbitmq_handler() -> None:
    """End-to-end smoke test: tenant create -> service request -> tenant delete."""
    tenant_service_url = settings.storage.tenant_server.endpoint
    config = RabbitMQConfig(
        host=settings.rabbitmq.host,
        port=settings.rabbitmq.port,
        username=settings.rabbitmq.username,
        password=settings.rabbitmq.password,
        heartbeat=settings.rabbitmq.heartbeat,
        input_queue_prefix=settings.rabbitmq.service_request_queue_prefix,
        # NOTE(review): tenant exchange name is taken from the *response queue
        # prefix* setting — looks like a copy-paste slip; confirm against the
        # settings schema before changing.
        tenant_exchange_name=settings.rabbitmq.service_response_queue_prefix,
        service_request_exchange_name=settings.rabbitmq.service_request_exchange_name,
        service_response_exchange_name=settings.rabbitmq.service_response_exchange_name,
        service_dead_letter_queue_name=settings.rabbitmq.service_dlq_name,
        queue_expiration_time=settings.rabbitmq.queue_expiration_time,
    )
    handler = RabbitMQHandler(config, tenant_service_url, dummy_message_processor)
    await handler.connect()
    await handler.setup_exchanges()
    await handler.initialize_tenant_queues()
    await handler.setup_tenant_queue()

    tenant_id = "test_tenant"

    # Test tenant creation
    create_message = {"tenantId": tenant_id}
    await handler.tenant_exchange.publish(
        Message(body=json.dumps(create_message).encode()), routing_key="tenant.create"
    )
    logger.info(f"Sent create tenant message for {tenant_id}")
    await asyncio.sleep(2)  # Wait for queue creation

    # Test service request
    service_request = {
        "tenantId": tenant_id,
        "dossierId": "dossier",
        "fileId": "file",
        "targetFileExtension": "json.gz",
        "responseFileExtension": "result.json.gz",
    }
    await handler.input_exchange.publish(
        Message(body=json.dumps(service_request).encode()), routing_key=tenant_id
    )
    logger.info(f"Sent service request for {tenant_id}")
    await asyncio.sleep(5)  # Wait for message processing

    # Test tenant deletion
    delete_message = {"tenantId": tenant_id}
    await handler.tenant_exchange.publish(
        Message(body=json.dumps(delete_message).encode()), routing_key="tenant.delete"
    )
    logger.info(f"Sent delete tenant message for {tenant_id}")
    await asyncio.sleep(2)  # Wait for queue deletion

    await handler.connection.close()


if __name__ == "__main__":
    asyncio.run(test_rabbitmq_handler())