diff --git a/pyinfra/examples.py b/pyinfra/examples.py index 9c21d9e..b1f7f48 100644 --- a/pyinfra/examples.py +++ b/pyinfra/examples.py @@ -5,7 +5,7 @@ from fastapi import FastAPI from kn_utils.logging import logger from pyinfra.config.loader import get_pyinfra_validators, validate_settings -from pyinfra.queue.async_tenants_v2 import RabbitMQConfig, RabbitMQHandler +from pyinfra.queue.async_manager import AsyncQueueManager, RabbitMQConfig from pyinfra.queue.callback import Callback from pyinfra.queue.manager import QueueManager from pyinfra.utils.opentelemetry import instrument_app, instrument_pika, setup_trace @@ -27,11 +27,13 @@ def get_rabbitmq_config(settings: Dynaconf): password=settings.rabbitmq.password, heartbeat=settings.rabbitmq.heartbeat, input_queue_prefix=settings.rabbitmq.service_request_queue_prefix, + tenant_event_queue_suffix=settings.rabbitmq.tenant_event_queue_suffix, tenant_exchange_name=settings.rabbitmq.service_response_queue_prefix, service_request_exchange_name=settings.rabbitmq.service_request_exchange_name, service_response_exchange_name=settings.rabbitmq.service_response_exchange_name, service_dead_letter_queue_name=settings.rabbitmq.service_dlq_name, queue_expiration_time=settings.rabbitmq.queue_expiration_time, + pod_name=settings.kubernetes.pod_name, ) @@ -66,7 +68,7 @@ def start_standard_queue_consumer( if settings.concurrency.enabled: config = get_rabbitmq_config(settings) - manager = RabbitMQHandler( + manager = AsyncQueueManager( config=config, tenant_service_url=settings.storage.tenant_server.endpoint, message_processor=callback ) else: @@ -77,7 +79,7 @@ def start_standard_queue_consumer( webserver_thread = create_webserver_thread_from_settings(app, settings) webserver_thread.start() - if isinstance(manager, RabbitMQHandler): + if isinstance(manager, AsyncQueueManager): asyncio.run(manager.run()) elif isinstance(manager, QueueManager): manager.start_consuming(callback) diff --git a/pyinfra/queue/async_tenants_v2.py b/pyinfra/queue/async_manager.py similarity index 94% rename from pyinfra/queue/async_tenants_v2.py rename to pyinfra/queue/async_manager.py index 52f1cff..715869e 100644 --- a/pyinfra/queue/async_tenants_v2.py +++ b/pyinfra/queue/async_manager.py @@ -23,11 +23,13 @@ class RabbitMQConfig: password: str heartbeat: int input_queue_prefix: str + tenant_event_queue_suffix: str tenant_exchange_name: str service_request_exchange_name: str service_response_exchange_name: str service_dead_letter_queue_name: str queue_expiration_time: int + pod_name: str connection_params: Dict[str, object] = field(init=False) @@ -41,7 +43,7 @@ class RabbitMQConfig: } -class RabbitMQHandler: +class AsyncQueueManager: def __init__( self, config: RabbitMQConfig, @@ -80,9 +82,8 @@ class RabbitMQHandler: ) async def setup_tenant_queue(self) -> None: - # TODO: Add k8s pod_name to tenant queue name - add DLQ? queue = await self.channel.declare_queue( - "tenant_queue", + f"{self.config.pod_name}_{self.config.tenant_event_queue_suffix}", durable=True, arguments={ "x-dead-letter-exchange": "", @@ -190,13 +191,14 @@ class RabbitMQHandler: try: async with aiohttp.ClientSession() as session: async with session.get(self.tenant_service_url) as response: - # TODO: dont know if we should check for 200, could also be 2xx - # maybe handle bad requests with response.raise_for_status() - if response.status == 200: + response.raise_for_status() + if response.headers["content-type"].lower() == "application/json": data = await response.json() return {tenant["tenantId"] for tenant in data} else: - logger.error(f"Failed to fetch active tenants. Status: {response.status}") + logger.error( + f"Failed to fetch active tenants. Content type is not JSON: {response.headers['content-type'].lower()}" + ) return set() except aiohttp.ClientError as e: logger.error(f"Error fetching active tenants: {e}") @@ -227,7 +229,7 @@ class RabbitMQHandler: logger.info("RabbitMQ handler is running. Press CTRL+C to exit.") await stop.wait() # Run until stop signal received except asyncio.CancelledError: - logger.info("Operation cancelled.") + logger.warning("Operation cancelled.") except Exception as e: logger.error(f"An error occurred: {e}", exc_info=True) finally: @@ -240,5 +242,3 @@ class RabbitMQHandler: if self.connection: await self.connection.close() logger.info("RabbitMQ handler shut down successfully.") - - # TODO: purge_queues diff --git a/pyinfra/queue/async_tenants.py b/pyinfra/queue/async_tenants.py deleted file mode 100644 index 9bdb3a6..0000000 --- a/pyinfra/queue/async_tenants.py +++ /dev/null @@ -1,435 +0,0 @@ -import asyncio -import concurrent.futures -import datetime -import json -import time -import uuid -from typing import Callable, Union - -import aio_pika -import aiormq -import requests -from aio_pika import DeliveryMode, Message -from aio_pika.abc import AbstractIncomingMessage -from dynaconf import Dynaconf -from kn_utils.logging import logger - -from pyinfra.config.loader import ( - load_settings, - local_pyinfra_root_path, - validate_settings, -) -from pyinfra.config.validators import queue_manager_validators - -MessageProcessor = Callable[[dict], dict] - - -class AsyncQueueManager: - - def __init__(self, settings: Dynaconf, message_processor: Callable = None) -> None: - validate_settings(settings, queue_manager_validators) - - self.message_processor = message_processor - self.connection_params = self.get_connection_params(settings) - self.connection = None - self.channel = None - - self.active_tenants = self.get_initial_tenant_ids(tenant_endpoint_url=settings.storage.tenant_server.endpoint) - self.consumer_tasks = {} - - self.connection_sleep = settings.rabbitmq.connection_sleep - self.queue_expiration_time = settings.rabbitmq.queue_expiration_time - - self.tenant_created_queue_name = self.get_tenant_created_queue_name(settings) - self.tenant_deleted_queue_name = self.get_tenant_deleted_queue_name(settings) - self.tenant_events_dlq_name = self.get_tenant_events_dlq_name(settings) - self.tenant_exchange_name = settings.rabbitmq.tenant_exchange_name - - self.service_request_exchange_name = settings.rabbitmq.service_request_exchange_name - self.service_response_exchange_name = settings.rabbitmq.service_response_exchange_name - - self.service_request_queue_prefix = settings.rabbitmq.service_request_queue_prefix - self.service_response_queue_prefix = settings.rabbitmq.service_response_queue_prefix - - self.service_dlq_name = settings.rabbitmq.service_dlq_name - - @staticmethod - def get_connection_params(settings: Dynaconf): - return { - "host": settings.rabbitmq.host, - "port": settings.rabbitmq.port, - "login": settings.rabbitmq.username, - "password": settings.rabbitmq.password, - "client_properties": {"heartbeat": settings.rabbitmq.heartbeat}, - } - - def get_initial_tenant_ids(self, tenant_endpoint_url: str) -> set: - response = requests.get(tenant_endpoint_url, timeout=10) - response.raise_for_status() # Raise an HTTPError for bad responses - - if response.headers["content-type"].lower() == "application/json": - tenants = {tenant["tenantId"] for tenant in response.json()} - return tenants - return set() - - def get_tenant_created_queue_name(self, settings: Dynaconf) -> str: - return self.get_queue_name_with_suffix( - suffix=settings.rabbitmq.tenant_created_event_queue_suffix, pod_name=settings.kubernetes.pod_name - ) - - def get_tenant_deleted_queue_name(self, settings: Dynaconf) -> str: - return self.get_queue_name_with_suffix( - suffix=settings.rabbitmq.tenant_deleted_event_queue_suffix, pod_name=settings.kubernetes.pod_name - ) - - def get_tenant_events_dlq_name(self, settings: Dynaconf) -> str: - return self.get_queue_name_with_suffix( - suffix=settings.rabbitmq.tenant_event_dlq_suffix, pod_name=settings.kubernetes.pod_name - ) - - def get_queue_name_with_suffix(self, suffix: str, pod_name: str) -> str: - if not self.use_default_queue_name() and pod_name: - return f"{pod_name}{suffix}" - return self.get_default_queue_name() - - def use_default_queue_name(self) -> bool: - return False - - def get_default_queue_name(self): - raise NotImplementedError("Queue name method not implemented") - - async def is_ready(self) -> bool: - await self.connect() - return self.channel.is_open - - #### ASYNC STUFF - async def purge_queues(self) -> None: - await self.establish_connection() - try: - for tenant_id in self.active_tenants: - service_request_queue = await self.channel.get_queue(f"{self.service_request_queue_prefix}_{tenant_id}") - await service_request_queue.purge() - service_response_queue = await self.channel.get_queue( - f"{self.service_response_queue_prefix}_{tenant_id}" - ) - await service_response_queue.purge() - logger.info("Queues purged.") - except aio_pika.exceptions.ChannelInvalidStateError: - pass - - async def connect(self): - self.connection = await aio_pika.connect_robust(**self.connection_params) - self.channel = await self.connection.channel() - logger.info("Connection established.") - - async def establish_connection(self): - await self.connect() - await self.initialize_queues() - logger.info("Queues initialized.") - # await self.start_processing() - - async def start_processing(self): - await self.establish_connection() - tenant_events = asyncio.create_task(self.handle_tenant_events()) - service_events = asyncio.create_task(self.start_consumers()) - - await asyncio.gather(tenant_events, service_events) - - async def initialize_queues(self): - await self.channel.set_qos(prefetch_count=1) - - service_request_exchange = await self.channel.declare_exchange( - name=self.service_request_exchange_name, type=aio_pika.ExchangeType.DIRECT, durable=True - ) - service_response_exchange = await self.channel.declare_exchange( - name=self.service_response_exchange_name, type=aio_pika.ExchangeType.DIRECT, durable=True - ) - - for tenant_id in self.active_tenants: - request_queue_name = f"{self.service_request_queue_prefix}_{tenant_id}" - request_queue = await self.channel.declare_queue( - name=request_queue_name, - durable=True, - arguments={ - "x-dead-letter-exchange": "", - "x-dead-letter-routing-key": self.service_dlq_name, - "x-expires": self.queue_expiration_time, # TODO: check if necessary - "x-max-priority": 2, - }, - ) - await request_queue.bind(exchange=service_request_exchange, routing_key=tenant_id) - - response_queue_name = f"{self.service_response_queue_prefix}_{tenant_id}" - response_queue = await self.channel.declare_queue( - name=response_queue_name, - durable=True, - arguments={ - "x-dead-letter-exchange": "", - "x-dead-letter-routing-key": self.service_dlq_name, - "x-expires": self.queue_expiration_time, # TODO: check if necessary - }, - ) - await response_queue.bind(exchange=service_response_exchange, routing_key=tenant_id) - - async def handle_tenant_events(self): - # Declare the topic exchange for tenant events - exchange = await self.channel.declare_exchange( - self.tenant_exchange_name, aio_pika.ExchangeType.TOPIC, durable=True - ) - # Declare a queue for receiving tenant events - queue = await self.channel.declare_queue( - "tenant_events_queue", - arguments={ - "x-dead-letter-exchange": "", - "x-dead-letter-routing-key": self.tenant_events_dlq_name, - }, - durable=True, - ) - - await queue.bind(exchange, routing_key="tenant.*") - - async with queue.iterator() as queue_iter: - async for message in queue_iter: - async with message.process(reject_on_redelivered=True): - routing_key = message.routing_key - message_body = json.loads(message.body.decode()) - tenant_id = message_body["tenantId"] - if routing_key == "tenant.created": - # Handle tenant creation - await self.handle_tenant_created(tenant_id) - - elif routing_key == "tenant.deleted": - # Handle tenant deletion - await self.handle_tenant_deleted(tenant_id) - else: - message.nack() - continue - message.ack() - await self.restart_consumers() - - async def handle_tenant_created(self, tenant_id): - # Handle creation of input and output queues for the new tenant - await self.create_tenant_queues(tenant_id) - await self.restart_consumers() - - async def handle_tenant_deleted(self, tenant_id): - # Handle deletion of input and output queues for the tenant - await self.delete_tenant_queues(tenant_id) - await self.restart_consumers() - - async def create_tenant_queues(self, tenant_id): - # Implement queue creation logic for the tenant - queue_name = f"{self.service_request_queue_prefix}_{tenant_id}" - queue = await self.channel.declare_queue( - name=queue_name, - durable=True, - arguments={ - "x-dead-letter-exchange": "", - "x-dead-letter-routing-key": self.service_dlq_name, - "x-expires": self.queue_expiration_time, # TODO: check if necessary - }, - ) - exchange = await self.channel.get_exchange(self.service_request_exchange_name) - await queue.bind(exchange=exchange, routing_key=tenant_id) - self.active_tenants.add(tenant_id) - logger.info(f"Created queue for tenant {tenant_id}") - - async def delete_tenant_queues(self, tenant_id): - queue_name = f"{self.service_request_queue_prefix}_{tenant_id}" - queue = await self.channel.get_queue(queue_name) - exchange = await self.channel.get_exchange(self.service_request_exchange_name) - await queue.unbind(exchange=exchange, routing_key=tenant_id) - await self.channel.queue_delete(queue_name) - self.active_tenants.discard(tenant_id) - - async def consume_from_request_queue(self, tenant_id): - queue_name = f"{self.service_request_queue_prefix}_{tenant_id}" - queue = await self.channel.get_queue(queue_name) - - async with queue.iterator() as queue_iter: - async for message in queue_iter: - async with message.process(): - on_message_callback = await self._make_on_message_callback(self.message_processor, tenant_id) - await on_message_callback(message) - - async def publish_to_service_response_queue(self, tenant_id, result): - service_response_exchange = await self.channel.get_exchange(self.service_response_exchange_name) - - await service_response_exchange.publish( - Message( - body=json.dumps(result).encode(), - delivery_mode=DeliveryMode.NOT_PERSISTENT, - timestamp=datetime.datetime.now(), - message_id=str(uuid.uuid4()), - ), - routing_key=tenant_id, - ) - - async def restart_consumers(self): - # Stop current consumers and start new ones for active tenants - await self.stop_consumers() - await self.start_consumers() - - async def start_consumers(self): - # Start consuming messages from input queues for active tenants - for tenant_id in self.active_tenants: - if tenant_id not in self.consumer_tasks: - self.consumer_tasks[tenant_id] = asyncio.create_task(self.consume_from_request_queue(tenant_id)) - - consumer_tasks = [self.consume_from_request_queue(tenant) for tenant in self.active_tenants] - await asyncio.gather(*consumer_tasks) - - async def stop_consumers(self): - for task in self.consumer_tasks.values(): - task.cancel() - try: - await task - except asyncio.CancelledError: - pass - self.consumer_tasks.clear() - - async def main_loop(self): - await self.establish_connection() - - async def shutdown(self): - # Implement cleanup logic - await self.stop_consumers() - if self.connection: - await self.connection.close() - - async def _make_on_message_callback(self, message_processor: MessageProcessor, tenant_id: str) -> Callable: - async def process_message_body_and_await_result(unpacked_message_body): - # Processing the message in a separate thread is necessary for the main thread pika client to be able to - # process data events (e.g. heartbeats) while the message is being processed. - # with concurrent.futures.ThreadPoolExecutor(max_workers=1) as thread_pool_executor: - # logger.info("Processing payload in separate thread.") - # future = thread_pool_executor.submit(message_processor, unpacked_message_body) - - # return future.result() - - return {"result": "lovely"} - - async def on_message_callback(message: AbstractIncomingMessage): - logger.info(f"Received message from queue with delivery_tag {message.delivery_tag}.") - - if message.redelivered: - logger.warning(f"Declining message with {message.delivery_tag=} due to it being redelivered.") - await message.nack(requeue=False) - return - - if message.body.decode("utf-8") == "STOP": - logger.info("Received stop signal, stopping consuming...") - await message.ack() - await self.stop_consumers() - return - - try: - filtered_message_headers = ( - {k: v for k, v in message.properties.headers.items() if k.lower().startswith("x-")} - if message.properties.headers - else {} - ) - logger.debug(f"Processing message with {filtered_message_headers=}.") - result: dict = await ( - process_message_body_and_await_result({**json.loads(message.body), **filtered_message_headers}) - or {} - ) - - await self.publish_to_service_response_queue(tenant_id, result) - logger.info(f"Published result to queue {tenant_id}.") - - await message.ack() - logger.debug(f"Message with {message.delivery_tag=} acknowledged.") - except FileNotFoundError as e: - logger.warning(f"{e}, declining message with {message.delivery_tag=}.") - await message.nack(requeue=False) - except Exception as e: - logger.warning(f"Failed to process message with {message.delivery_tag=}, declining...", exc_info=True) - logger.warning(e) - await message.nack(requeue=False) - raise - - return on_message_callback - - async def publish_message_to_input_queue(self, tenant_id: str, message: Union[str, bytes, dict]) -> None: - if isinstance(message, str): - message = message.encode("utf-8") - elif isinstance(message, dict): - message = json.dumps(message).encode("utf-8") - - await self.establish_connection() - - service_request_exchange = await self.channel.get_exchange(self.service_request_exchange_name) - - await service_request_exchange.publish( - message=Message( - body=message, - delivery_mode=DeliveryMode.NOT_PERSISTENT, - timestamp=datetime.datetime.now(), - message_id=str(uuid.uuid4()), - ), - routing_key=tenant_id, - ) - - logger.info(f"Published message to queue {tenant_id}.") - - async def publish_message_to_tenant_created_queue(self, message: Union[str, bytes, dict]) -> None: - if isinstance(message, str): - message = message.encode("utf-8") - elif isinstance(message, dict): - message = json.dumps(message).encode("utf-8") - - await self.establish_connection() - service_request_exchange = await self.channel.get_exchange(self.tenant_exchange_name) - - await service_request_exchange.publish( - message=Message( - body=message, - delivery_mode=DeliveryMode.NOT_PERSISTENT, - timestamp=datetime.datetime.now(), - message_id=str(uuid.uuid4()), - ), - routing_key="tenant.created", - ) - - logger.info(f"Published message to queue {self.tenant_created_queue_name}.") - - async def publish_message_to_tenant_deleted_queue(self, message: Union[str, bytes, dict]) -> None: - if isinstance(message, str): - message = message.encode("utf-8") - elif isinstance(message, dict): - message = json.dumps(message).encode("utf-8") - - await self.establish_connection() - service_request_exchange = await self.channel.get_exchange(self.tenant_exchange_name) - - await service_request_exchange.publish( - message=Message( - body=message, - delivery_mode=DeliveryMode.NOT_PERSISTENT, - timestamp=datetime.datetime.now(), - message_id=str(uuid.uuid4()), - ), - routing_key="tenant.delete", - ) - - logger.info(f"Published message to queue {self.tenant_deleted_queue_name}.") - - -async def main() -> None: - import time - - settings = load_settings(local_pyinfra_root_path / "config/") - callback = "" - - manager = AsyncQueueManager(settings=settings, message_processor=callback) - - await manager.main_loop() - - while True: - time.sleep(100) - print("keep idling") - - -if __name__ == "__main__": - asyncio.run(main()) diff --git a/scripts/send_async_request.py b/scripts/send_async_request.py index 1c095d3..4b44cf4 100644 --- a/scripts/send_async_request.py +++ b/scripts/send_async_request.py @@ -9,7 +9,8 @@ from aio_pika.abc import AbstractIncomingMessage from kn_utils.logging import logger from pyinfra.config.loader import load_settings, local_pyinfra_root_path -from pyinfra.queue.async_tenants_v2 import RabbitMQConfig, RabbitMQHandler +from pyinfra.examples import get_rabbitmq_config +from pyinfra.queue.async_manager import AsyncQueueManager from pyinfra.storage.storages.s3 import S3Storage, get_s3_storage_from_settings settings = load_settings(local_pyinfra_root_path / "config/") @@ -88,29 +89,12 @@ def upload_json_and_make_message_body(tenant_id: str): async def test_rabbitmq_handler() -> None: tenant_service_url = settings.storage.tenant_server.endpoint - config = RabbitMQConfig( - host=settings.rabbitmq.host, - port=settings.rabbitmq.port, - username=settings.rabbitmq.username, - password=settings.rabbitmq.password, - heartbeat=settings.rabbitmq.heartbeat, - input_queue_prefix=settings.rabbitmq.service_request_queue_prefix, - tenant_exchange_name=settings.rabbitmq.service_response_queue_prefix, - service_request_exchange_name=settings.rabbitmq.service_request_exchange_name, - service_response_exchange_name=settings.rabbitmq.service_response_exchange_name, - service_dead_letter_queue_name=settings.rabbitmq.service_dlq_name, - queue_expiration_time=settings.rabbitmq.queue_expiration_time, - ) + config = get_rabbitmq_config(settings) - handler = RabbitMQHandler(config, tenant_service_url, dummy_message_processor) + handler = AsyncQueueManager(config, tenant_service_url, dummy_message_processor) await handler.connect() await handler.setup_exchanges() - # await handler.initialize_tenant_queues() - # await handler.setup_tenant_queue() - - # for queue in handler.tenant_queues.values(): - # await queue.purge() tenant_id = "test_tenant"