import asyncio import gzip from operator import itemgetter from typing import Any, Callable, Dict, Set import aiohttp from aio_pika import connect_robust, ExchangeType, Message, IncomingMessage from aio_pika.abc import AbstractIncomingMessage import json from logging import getLogger from pyinfra.config.loader import load_settings, local_pyinfra_root_path from pyinfra.storage.storages.s3 import get_s3_storage_from_settings logger = getLogger(__name__) logger.setLevel("DEBUG") settings = load_settings(local_pyinfra_root_path / "config/") class RabbitMQHandler: def __init__(self, connection_params, tenant_service_url, message_processor): self.connection_params = connection_params self.tenant_service_url = tenant_service_url # TODO: remove hardcoded values self.input_queue_prefix = "service_request_queue" self.tenant_exchange_name = "tenants-exchange" self.service_request_exchange_name = "service_request_exchange" # INPUT self.service_response_exchange_name = "service_response_exchange" # OUTPUT self.service_dead_letter_queue_name = "service_dlq" self.queue_expiration_time = 300000 self.message_processor = message_processor self.connection = None self.channel = None self.tenant_exchange = None self.input_exchange = None self.output_exchange = None self.tenant_queues = {} async def connect(self): self.connection = await connect_robust(**self.connection_params) self.channel = await self.connection.channel() # Declare exchanges self.tenant_exchange = await self.channel.declare_exchange( self.tenant_exchange_name, ExchangeType.TOPIC, durable=True ) self.input_exchange = await self.channel.declare_exchange( self.service_request_exchange_name, ExchangeType.DIRECT, durable=True ) self.output_exchange = await self.channel.declare_exchange( self.service_response_exchange_name, ExchangeType.DIRECT, durable=True ) async def setup_tenant_queue(self): queue = await self.channel.declare_queue( "tenant_queue", durable=True, arguments={ "x-dead-letter-exchange": "", "x-dead-letter-routing-key": self.service_dead_letter_queue_name, }, ) await queue.bind(self.tenant_exchange, routing_key="tenant.*") await queue.consume(self.process_tenant_message) async def process_tenant_message(self, message: AbstractIncomingMessage): async with message.process(): message_body = json.loads(message.body.decode()) print(message_body) logger.debug(f"input message: {message_body}") tenant_id = message_body["tenantId"] routing_key = message.routing_key if routing_key == "tenant.create": await self.create_tenant_queues(tenant_id) elif routing_key == "tenant.delete": await self.delete_tenant_queues(tenant_id) async def create_tenant_queues(self, tenant_id): # Create and bind input queue queue_name = f"{self.input_queue_prefix}_{tenant_id}" print(f"queue declared: {queue_name}") input_queue = await self.channel.declare_queue( queue_name, durable=True, arguments={ "x-dead-letter-exchange": "", "x-dead-letter-routing-key": self.service_dead_letter_queue_name, "x-expires": self.queue_expiration_time, "x-max-priority": 2, }, ) await input_queue.bind(self.input_exchange, routing_key=tenant_id) await input_queue.consume(lambda msg: self.process_input_message(msg, self.message_processor)) # Store queues for later use self.tenant_queues[tenant_id] = input_queue print(f"Created queues for tenant {tenant_id}") async def delete_tenant_queues(self, tenant_id): if tenant_id in self.tenant_queues: input_queue = self.tenant_queues[tenant_id] await input_queue.delete() del self.tenant_queues[tenant_id] print(f"Deleted queues for tenant {tenant_id}") async def process_input_message( self, message: IncomingMessage, message_processor: Callable[[Dict[str, Any]], Dict[str, Any]] ) -> None: async with message.process(): try: tenant_id = message.routing_key message_body = json.loads(message.body.decode("utf-8")) # Extract headers filtered_message_headers = ( {k: v for k, v in message.headers.items() if k.lower().startswith("x-")} if message.headers else {} ) logger.debug(f"Processing message with {filtered_message_headers=}.") # Process the message message_body.update(filtered_message_headers) # Run the message processor in a separate thread to avoid blocking loop = asyncio.get_running_loop() result = await loop.run_in_executor(None, message_processor, message_body) if result: # Publish the result to the output exchange self.publish_to_output_exchange(tenant_id, result, filtered_message_headers) except json.JSONDecodeError: logger.error(f"Invalid JSON in input message: {message.body}") # Message will be nacked and sent to dead-letter queue except FileNotFoundError as e: logger.warning(f"{e}, declining message.") # Message will be nacked and sent to dead-letter queue except Exception as e: logger.error(f"Error processing input message: {e}", exc_info=True) # Message will be nacked and sent to dead-letter queue raise async def publish_to_output_exchange(self, tenant_id: str, result: Dict[str, Any], headers: Dict[str, Any]): await self.output_exchange.publish( Message(body=json.dumps(result).encode(), headers=headers), routing_key=tenant_id, ) logger.info(f"Published result to queue {tenant_id}.") async def fetch_active_tenants(self) -> Set[str]: try: async with aiohttp.ClientSession() as session: async with session.get(self.tenant_service_url) as response: if response.status == 200: data = await response.json() return {tenant["tenantId"] for tenant in data} else: logger.error(f"Failed to fetch active tenants. Status: {response.status}") return set() except aiohttp.ClientError as e: logger.error(f"Error fetching active tenants: {e}") return set() async def initialize_tenant_queues(self): active_tenants = await self.fetch_active_tenants() for tenant_id in active_tenants: await self.create_tenant_queues(tenant_id) async def run(self): await self.connect() await self.initialize_tenant_queues() await self.setup_tenant_queue() print("RabbitMQ handler is running. Press CTRL+C to exit.") try: await asyncio.Future() # Run forever finally: await self.connection.close() async def main(): connection_params = { "host": settings.rabbitmq.host, "port": settings.rabbitmq.port, "login": settings.rabbitmq.username, "password": settings.rabbitmq.password, "client_properties": {"heartbeat": settings.rabbitmq.heartbeat}, } tenant_service_url = "http://localhost:8080/internal/tenants" handler = RabbitMQHandler(connection_params, tenant_service_url, dummy_message_processor) await handler.run() ######################################################################## def upload_json_and_make_message_body(tenant_id: str) -> Dict[str, Any]: dossier_id, file_id, suffix = "dossier", "file", "json.gz" content = { "numberOfPages": 7, "sectionTexts": "data", } object_name = f"{tenant_id}/{dossier_id}/{file_id}.{suffix}" data = gzip.compress(json.dumps(content).encode("utf-8")) storage = get_s3_storage_from_settings(settings) if not storage.has_bucket(): storage.make_bucket() storage.put_object(object_name, data) message_body = { "tenantId": tenant_id, "dossierId": dossier_id, "fileId": file_id, "targetFileExtension": suffix, "responseFileExtension": f"result.{suffix}", } return message_body def tenant_event_message(tenant_id: str) -> Dict[str, str]: return {"tenantId": tenant_id} async def dummy_message_processor(message: Dict[str, Any]) -> Dict[str, Any]: logger.info(f"Processing message: {message}") await asyncio.sleep(1) # Simulate processing time storage = get_s3_storage_from_settings(settings) tenant_id, dossier_id, file_id = itemgetter("tenantId", "dossierId", "fileId")(message) suffix = message["responseFileExtension"] # Simulate processing by modifying the original content object_name = f"{tenant_id}/{dossier_id}/{file_id}.{message['targetFileExtension']}" original_content = json.loads(gzip.decompress(storage.get_object(object_name))) processed_content = { "processedPages": original_content["numberOfPages"], "processedSectionTexts": f"Processed: {original_content['sectionTexts']}", } # Save processed content processed_object_name = f"{tenant_id}/{dossier_id}/{file_id}.{suffix}" processed_data = gzip.compress(json.dumps(processed_content).encode("utf-8")) storage.put_object(processed_object_name, processed_data) processed_message = message.copy() processed_message["processed"] = True processed_message["processor_message"] = "This message was processed by the dummy processor" logger.info(f"Finished processing message. Result: {processed_message}") return processed_message async def test_rabbitmq_handler(): connection_params = { "host": settings.rabbitmq.host, "port": settings.rabbitmq.port, "login": settings.rabbitmq.username, "password": settings.rabbitmq.password, "client_properties": {"heartbeat": settings.rabbitmq.heartbeat}, } tenant_service_url = "http://localhost:8080/internal/tenants" handler = RabbitMQHandler(connection_params, tenant_service_url, dummy_message_processor) await handler.connect() await handler.setup_tenant_queue() # Test tenant creation tenant_id = "test_tenant" create_message = tenant_event_message(tenant_id) await handler.tenant_exchange.publish( Message(body=json.dumps(create_message).encode()), routing_key="tenant.create" ) logger.info(f"Sent create tenant message for {tenant_id}") # Wait for tenant queue creation await asyncio.sleep(2) # Test service request service_request = upload_json_and_make_message_body(tenant_id) await handler.input_exchange.publish(Message(body=json.dumps(service_request).encode()), routing_key=tenant_id) logger.info(f"Sent service request for {tenant_id}") # Wait for message processing await asyncio.sleep(5) # Test tenant deletion delete_message = tenant_event_message(tenant_id) await handler.tenant_exchange.publish( Message(body=json.dumps(delete_message).encode()), routing_key="tenant.delete" ) logger.info(f"Sent delete tenant message for {tenant_id}") # Wait for tenant queue deletion await asyncio.sleep(2) await handler.connection.close() if __name__ == "__main__": asyncio.run(test_rabbitmq_handler())