diff --git a/scripts/test_async_class.py b/scripts/test_async_class.py index 3ac82e7..95af16b 100644 --- a/scripts/test_async_class.py +++ b/scripts/test_async_class.py @@ -1,7 +1,7 @@ import asyncio -from typing import Set +from typing import Any, Callable, Dict, Set import aiohttp -from aio_pika import connect_robust, ExchangeType, Message +from aio_pika import connect_robust, ExchangeType, Message, IncomingMessage from aio_pika.abc import AbstractIncomingMessage import json from logging import getLogger @@ -15,7 +15,7 @@ settings = load_settings(local_pyinfra_root_path / "config/") class RabbitMQHandler: - def __init__(self, connection_params, tenant_service_url): + def __init__(self, connection_params, tenant_service_url, message_processor): self.connection_params = connection_params self.tenant_service_url = tenant_service_url # TODO: remove hardcoded values @@ -25,6 +25,7 @@ class RabbitMQHandler: self.service_response_exchange_name = "service_response_exchange" # OUTPUT self.service_dead_letter_queue_name = "service_dlq" self.queue_expiration_time = 300000 + self.message_processor = message_processor self.connection = None self.channel = None self.tenant_exchange = None @@ -87,7 +88,7 @@ class RabbitMQHandler: }, ) await input_queue.bind(self.input_exchange, routing_key=tenant_id) - await input_queue.consume(self.process_input_message) + await input_queue.consume(lambda msg: self.process_input_message(msg, self.message_processor)) # Store queues for later use self.tenant_queues[tenant_id] = input_queue @@ -100,19 +101,46 @@ class RabbitMQHandler: del self.tenant_queues[tenant_id] print(f"Deleted queues for tenant {tenant_id}") - async def process_input_message(self, message: AbstractIncomingMessage): + async def process_input_message( + self, message: IncomingMessage, message_processor: Callable[[Dict[str, Any]], Dict[str, Any]] + ) -> None: async with message.process(): - message_body = json.loads(message.body.decode()) - logger.debug(f"input message: {message_body}") - # Process the incoming message - processed_content = f"Processed: {message_body}" + try: + tenant_id = message.routing_key + message_body = json.loads(message.body.decode("utf-8")) - # TODO: add additional processing logic here - # ... + # Extract headers + filtered_message_headers = ( + {k: v for k, v in message.headers.items() if k.lower().startswith("x-")} if message.headers else {} + ) - # Publish to the output queue - tenant_id = message.routing_key - await self.output_exchange.publish(Message(body=processed_content.encode()), routing_key=tenant_id) + logger.debug(f"Processing message with {filtered_message_headers=}.") + + # Process the message + message_body.update(filtered_message_headers) + + # Run the message processor in a separate thread to avoid blocking + loop = asyncio.get_running_loop() + result = await loop.run_in_executor(None, message_processor, message_body) + + if result: + # Publish the result to the output exchange + await self.output_exchange.publish( + Message(body=json.dumps(result).encode(), headers=filtered_message_headers), + routing_key=tenant_id, + ) + logger.info(f"Published result to queue {tenant_id}.") + + except json.JSONDecodeError: + logger.error(f"Invalid JSON in input message: {message.body}") + # Message will be nacked and sent to dead-letter queue + except FileNotFoundError as e: + logger.warning(f"{e}, declining message.") + # Message will be nacked and sent to dead-letter queue + except Exception as e: + logger.error(f"Error processing input message: {e}", exc_info=True) + # Message will be nacked and sent to dead-letter queue + raise async def fetch_active_tenants(self) -> Set[str]: try: