diff --git a/scripts/test_async_class.py b/scripts/test_async_class.py
new file mode 100644
index 0000000..f9eee3e
--- /dev/null
+++ b/scripts/test_async_class.py
@@ -0,0 +1,165 @@
+import asyncio
+import aiohttp
+from aio_pika import connect_robust, ExchangeType, Message
+from aio_pika.abc import AbstractIncomingMessage
+import json
+from logging import basicConfig, getLogger
+from pyinfra.config.loader import load_settings, local_pyinfra_root_path
+import requests
+
+basicConfig(level="DEBUG")  # without a handler the DEBUG messages below would never be emitted
+logger = getLogger(__name__)
+
+settings = load_settings(local_pyinfra_root_path / "config/")
+
+
+class RabbitMQHandler:
+    def __init__(self, connection_params, tenant_service_url):
+        self.connection_params = connection_params
+        self.tenant_service_url = tenant_service_url
+        # TODO: remove hardcoded values
+        self.input_queue_prefix = "service_request_queue"
+        self.tenant_exchange_name = "tenants-exchange"
+        self.service_request_exchange_name = "service_request_exchange"  # INPUT
+        self.service_response_exchange_name = "service_response_exchange"  # OUTPUT
+        self.service_dead_letter_queue_name = "service_dlq"
+        self.queue_expiration_time = 3_600_000  # per-tenant queue TTL in ms (placeholder, see TODO above)
+        self.connection = None
+        self.channel = None
+        self.tenant_exchange = None
+        self.input_exchange = None
+        self.output_exchange = None
+        self.tenant_queues = {}
+
+    async def connect(self):
+        self.connection = await connect_robust(**self.connection_params)
+        self.channel = await self.connection.channel()
+
+        # Declare exchanges
+        self.tenant_exchange = await self.channel.declare_exchange(
+            self.tenant_exchange_name, ExchangeType.TOPIC, durable=True
+        )
+        self.input_exchange = await self.channel.declare_exchange(
+            self.service_request_exchange_name, ExchangeType.DIRECT, durable=True
+        )
+        self.output_exchange = await self.channel.declare_exchange(
+            self.service_response_exchange_name, ExchangeType.DIRECT, durable=True
+        )
+
+    async def setup_tenant_queue(self):
+        queue = await self.channel.declare_queue(
+            "tenant_queue",
+            durable=True,
+            arguments={
+                "x-dead-letter-exchange": "",
+                "x-dead-letter-routing-key": self.service_dead_letter_queue_name,
+            },
+        )
+        await queue.bind(self.tenant_exchange, routing_key="tenant.*")
+        await queue.consume(self.process_tenant_message)
+
+    async def process_tenant_message(self, message: AbstractIncomingMessage):
+        async with message.process():
+            message_body = json.loads(message.body.decode())
+            logger.debug(f"tenant message: {message_body}")
+            tenant_id = message_body["queue_name"]
+            routing_key = message.routing_key
+
+            if routing_key == "tenant.create":
+                await self.create_tenant_queues(tenant_id)
+            elif routing_key == "tenant.delete":
+                await self.delete_tenant_queues(tenant_id)
+
+    async def create_tenant_queues(self, tenant_id):
+        # Create and bind the per-tenant input queue
+        queue_name = f"{self.input_queue_prefix}_{tenant_id}"
+        input_queue = await self.channel.declare_queue(
+            queue_name,
+            durable=True,
+            arguments={
+                "x-dead-letter-exchange": "",
+                "x-dead-letter-routing-key": self.service_dead_letter_queue_name,
+                "x-expires": self.queue_expiration_time,
+                "x-max-priority": 2,
+            },
+        )
+        await input_queue.bind(self.input_exchange, routing_key=tenant_id)
+        await input_queue.consume(self.process_input_message)
+
+        # Store the queue so it can be deleted when the tenant is removed
+        self.tenant_queues[tenant_id] = input_queue
+        print(f"Created queues for tenant {tenant_id}")
+
+    async def delete_tenant_queues(self, tenant_id):
+        if tenant_id in self.tenant_queues:
+            input_queue = self.tenant_queues[tenant_id]
+            # Force deletion even though this handler still has a consumer attached
+            await input_queue.delete(if_unused=False, if_empty=False)
+            del self.tenant_queues[tenant_id]
+            print(f"Deleted queues for tenant {tenant_id}")
+
+    async def process_input_message(self, message: AbstractIncomingMessage):
+        async with message.process():
+            message_body = json.loads(message.body.decode())
+            logger.debug(f"input message: {message_body}")
+            # Process the incoming message
+            processed_content = f"Processed: {message_body}"
+
+            # TODO: add additional processing logic here
+            # ...
+
+            # Publish the result to the output exchange, keyed by tenant
+            tenant_id = message.routing_key
+            await self.output_exchange.publish(
+                Message(body=processed_content.encode()), routing_key=tenant_id
+            )
+
+    async def fetch_active_tenants(self):
+        async with aiohttp.ClientSession() as session:
+            async with session.get(self.tenant_service_url) as response:
+                if response.status == 200 and response.content_type == "application/json":
+                    payload = await response.json()
+                    return {tenant["tenantId"] for tenant in payload}
+                print(f"Failed to fetch active tenants. Status: {response.status}")
+                return set()
+
+    # TODO: remove once initialize_tenant_queues switches to fetch_active_tenants
+    def get_initial_tenant_ids(self) -> set:
+        response = requests.get(self.tenant_service_url, timeout=10)
+        response.raise_for_status()  # Raise an HTTPError for bad responses
+
+        if response.headers.get("content-type", "").lower().startswith("application/json"):
+            return {tenant["tenantId"] for tenant in response.json()}
+        return set()
+
+    async def initialize_tenant_queues(self):
+        active_tenants = self.get_initial_tenant_ids()
+        for tenant_id in active_tenants:
+            await self.create_tenant_queues(tenant_id)
+
+    async def run(self):
+        await self.connect()
+        await self.initialize_tenant_queues()
+        await self.setup_tenant_queue()
+        print("RabbitMQ handler is running. Press CTRL+C to exit.")
+        try:
+            await asyncio.Future()  # Run forever
+        finally:
+            await self.connection.close()
+
+
+async def main():
+    connection_params = {
+        "host": settings.rabbitmq.host,
+        "port": settings.rabbitmq.port,
+        "login": settings.rabbitmq.username,
+        "password": settings.rabbitmq.password,
+        "client_properties": {"heartbeat": settings.rabbitmq.heartbeat},
+    }
+    tenant_service_url = "http://localhost:8080/internal/tenants"
+    handler = RabbitMQHandler(connection_params, tenant_service_url)
+    await handler.run()
+
+
+if __name__ == "__main__":
+    asyncio.run(main())
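
For a quick end-to-end check of the tenant-provisioning path, a message can be published to the tenants-exchange that process_tenant_message consumes. The sketch below is illustrative only: the exchange name, routing key, and the "queue_name" field mirror what the handler expects, while the AMQP URL (amqp://guest:guest@localhost/) and the tenant id are assumptions for a local broker.

import asyncio
import json

from aio_pika import DeliveryMode, ExchangeType, Message, connect_robust


async def publish_tenant_created(tenant_id: str = "tenant-123") -> None:
    # Assumed local broker URL; adjust to match settings.rabbitmq in practice.
    connection = await connect_robust("amqp://guest:guest@localhost/")
    async with connection:
        channel = await connection.channel()
        # Same exchange name/type the handler declares, so this declaration is idempotent.
        exchange = await channel.declare_exchange(
            "tenants-exchange", ExchangeType.TOPIC, durable=True
        )
        body = json.dumps({"queue_name": tenant_id}).encode()
        await exchange.publish(
            Message(body=body, delivery_mode=DeliveryMode.PERSISTENT),
            routing_key="tenant.create",  # matches the handler's tenant.* binding
        )


if __name__ == "__main__":
    asyncio.run(publish_tenant_created())

Running this while the handler is up should result in a "Created queues for tenant tenant-123" line from create_tenant_queues; routing_key="tenant.delete" exercises the deletion path the same way.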