diff --git a/pyinfra/queue/multiple_tenants.py b/pyinfra/queue/multiple_tenants.py
index 5d0b948..a3eac52 100644
--- a/pyinfra/queue/multiple_tenants.py
+++ b/pyinfra/queue/multiple_tenants.py
@@ -1,15 +1,20 @@
 import atexit
+import concurrent.futures
 import pika
 import os
 import json
 import logging
 import signal
 import sys
+import requests
+import time
+import pika.exceptions
 from threading import Thread
 from dynaconf import Dynaconf
 from typing import Callable, Union
 from kn_utils.logging import logger
 from pika.adapters.blocking_connection import BlockingChannel, BlockingConnection
+from pika.adapters.select_connection import SelectConnection
 from pika.channel import Channel
 from retry import retry
 
@@ -20,20 +25,22 @@ from pyinfra.config.validators import queue_manager_validators
 pika_logger = logging.getLogger("pika")
 pika_logger.setLevel(logging.WARNING)  # disables non-informative pika log clutter
 
+MessageProcessor = Callable[[dict], dict]
+
 
 class BaseQueueManager:
+    tenant_ids = []
+
     def __init__(self, settings: Dynaconf):
         validate_settings(settings, queue_manager_validators)
         self.connection_parameters = self.create_connection_parameters(settings)
-        self.connection: Union[BlockingConnection, None] = None
-        self.channel: Union[BlockingChannel, None] = None
+        self.connection = None
+        self.channel = None
 
         self.connection_sleep = settings.rabbitmq.connection_sleep
         self.queue_expiration_time = settings.rabbitmq.queue_expiration_time
         self.tenant_exchange_name = settings.rabbitmq.tenant_exchange_name
 
-        tenant_ids = []
-
         atexit.register(self.stop_consuming)
         signal.signal(signal.SIGTERM, self._handle_stop_signal)
         signal.signal(signal.SIGINT, self._handle_stop_signal)
@@ -49,23 +56,54 @@ class BaseQueueManager:
         }
         return pika.ConnectionParameters(**pika_connection_params)
 
-    @retry(tries=3, delay=5, jitter=(1, 3), logger=logger)
+    # @retry(tries=3, delay=5, jitter=(1, 3), logger=logger)
+    # def establish_connection(self):
+    #     if self.connection and self.connection.is_open:
+    #         logger.debug("Connection to RabbitMQ already established.")
+    #         return
+
+    #     logger.info("Establishing connection to RabbitMQ...")
+    #     self.connection = pika.BlockingConnection(parameters=self.connection_parameters)
+
+    #     logger.debug("Opening channel...")
+    #     self.channel = self.connection.channel()
+    #     self.channel.basic_qos(prefetch_count=1)
+    #     self.initialize_queues()
+
+    #     logger.info("Connection to RabbitMQ established, channel open.")
+
     def establish_connection(self):
         if self.connection and self.connection.is_open:
             logger.debug("Connection to RabbitMQ already established.")
             return
 
         logger.info("Establishing connection to RabbitMQ...")
-        self.connection = pika.BlockingConnection(parameters=self.connection_parameters)
+        self.connection = SelectConnection(parameters=self.connection_parameters,
+                                           on_open_callback=self.on_connection_open,
+                                           on_open_error_callback=self.on_connection_open_error,
+                                           on_close_callback=self.on_connection_close)
+
+    def on_connection_open(self, unused_connection):
+        logger.debug("Connection opened")
+        self.connection.channel(on_open_callback=self.on_channel_open)
 
-        logger.debug("Opening channel...")
-        self.channel = self.connection.channel()
+    def on_connection_open_error(self, unused_connection, err):
+        logger.error(f"Connection open failed, reopening in {self.connection_sleep} seconds: {err}")
+        self.connection.ioloop.call_later(self.connection_sleep, self.establish_connection)
+
+    def on_connection_close(self, unused_connection, reason):
+        logger.warning(f"Connection closed, reopening in {self.connection_sleep} seconds: {reason}")
+        self.connection.ioloop.call_later(self.connection_sleep, self.establish_connection)
+
+    def on_channel_open(self, channel):
+        logger.debug("Channel opened")
+        self.channel = channel
         self.channel.basic_qos(prefetch_count=1)
         self.initialize_queues()
-        logger.info("Connection to RabbitMQ established, channel open.")
 
-        logger.info("Starting to consume messages...")
-        Thread(target=self.channel.start_consuming).start()
+    def is_ready(self):
+        self.establish_connection()
+        return self.channel.is_open
 
     def initialize_queues(self):
         raise NotImplementedError("Subclasses should implement this method")
@@ -94,8 +132,11 @@ class TenantQueueManager(BaseQueueManager):
         self.tenant_created_queue_name = self.get_tenant_created_queue_name(settings)
         self.tenant_deleted_queue_name = self.get_tenant_deleted_queue_name(settings)
         self.tenant_events_dlq_name = self.get_tenant_events_dlq_name(settings)
+        self.event_handlers = {"tenant_created": [], "tenant_deleted": []}
 
-        self.tenant_ids = []
+        TenantQueueManager.tenant_ids = self.get_initial_tenant_ids(
+            tenant_endpoint_url=settings.storage.tenant_server.endpoint
+        )
 
     def initialize_queues(self):
         self.channel.exchange_declare(exchange=self.tenant_exchange_name, exchange_type="topic")
@@ -134,6 +175,21 @@ class TenantQueueManager(BaseQueueManager):
         self.channel.basic_consume(queue=self.tenant_created_queue_name, on_message_callback=self.on_tenant_created)
         self.channel.basic_consume(queue=self.tenant_deleted_queue_name, on_message_callback=self.on_tenant_deleted)
 
+    @retry(tries=3, delay=5, jitter=(1, 3), logger=logger, exceptions=requests.exceptions.HTTPError)
+    def get_initial_tenant_ids(self, tenant_endpoint_url: str) -> list:
+        try:
+            response = requests.get(tenant_endpoint_url, timeout=10)
+            response.raise_for_status()  # Raise an HTTPError for bad responses
+
+            if response.headers["content-type"].lower() == "application/json":
+                tenant_ids = [tenant["tenantId"] for tenant in response.json()]
+            else:
+                logger.warning("Response is not in JSON format.")
+        except Exception as e:
+            logger.warning(f"An unexpected error occurred: {e}")
+
+        return tenant_ids
+
     def get_tenant_created_queue_name(self, settings: Dynaconf):
         return self.get_queue_name_with_suffix(
             suffix=settings.rabbitmq.tenant_created_event_queue_suffix, pod_name=settings.kubernetes.pod_name
@@ -166,9 +222,10 @@ class TenantQueueManager(BaseQueueManager):
         logger.info(f"Tenant Created: {message}")
         ch.basic_ack(delivery_tag=method.delivery_tag)
 
-        #TODO: replace this w/ working callback
-        tenant_id = body["tenant_id"]
-        self.tenant_ids.append(tenant_id)
+        # TODO: test callback
+        tenant_id = message["tenantId"]
+        TenantQueueManager.tenant_ids.append(tenant_id)
+        self._trigger_event("tenant_created", tenant_id)
 
     def on_tenant_deleted(self, ch, method, properties, body):
         logger.info("Received tenant deleted event")
@@ -176,13 +233,38 @@
         logger.info(f"Tenant Deleted: {message}")
         ch.basic_ack(delivery_tag=method.delivery_tag)
 
-        #TODO: replace this w/ working callback
-        tenant_id = body["tenant_id"]
-        self.tenant_ids.remove(tenant_id)
+        # TODO: test callback
+        tenant_id = message["tenantId"]
+        TenantQueueManager.tenant_ids.remove(tenant_id)
+        self._trigger_event("tenant_deleted", tenant_id)
+
+    def _trigger_event(self, event_type, tenant_id):
+        handler = self.event_handlers.get(event_type)
+        if handler:
+            try:
+                handler(tenant_id)
+            except Exception as e:
+                logger.error(f"Error in event handler for {event_type}: {e}", exc_info=True)
+
+    def add_event_handler(self, event_type: str, handler: Callable[[str], None]):
+        if event_type in self.event_handlers:
+            self.event_handlers[event_type] = handler
+        else:
+            logger.warning(f"Unknown event type: {event_type}")
+
+    def purge_queues(self):
+        self.establish_connection()
+        try:
+            self.channel.queue_purge(self.tenant_created_queue_name)
+            self.channel.queue_purge(self.tenant_deleted_queue_name)
+            self.channel.queue_purge(self.tenant_events_dlq_name)
+            logger.info("Queues purged.")
+        except pika.exceptions.ChannelWrongStateError:
+            pass
 
 
 class ServiceQueueManager(BaseQueueManager):
-    def __init__(self, settings: Dynaconf):
+    def __init__(self, settings: Dynaconf, tenant_manager: TenantQueueManager):
         super().__init__(settings)
@@ -190,37 +272,148 @@
         self.service_request_exchange_name = settings.rabbitmq.service_request_exchange_name
         self.service_queue_prefix = settings.rabbitmq.service_request_queue_prefix
         self.service_dlq_name = settings.rabbitmq.service_dlq_name
 
+        tenant_manager.add_event_handler("tenant_created", self.add_tenant_queue)
+        tenant_manager.add_event_handler("tenant_deleted", self.delete_tenant_queue)
+
     def initialize_queues(self):
-        self.channel.exchange_declare(exchange=self.service_request_exchange_name, exchange_type="topic")
-        queue_name = self.service_queue_prefix + "default"
-        self.channel.queue_declare(queue=queue_name, arguments={"x-max-priority": 2})
-        self.channel.queue_bind(exchange=self.service_request_exchange_name, queue=queue_name)
+        self.channel.exchange_declare(exchange=self.service_request_exchange_name, exchange_type="direct")
+        self.channel.exchange_declare(exchange=self.service_response_exchange_name, exchange_type="direct")
+
+        for tenant_id in ServiceQueueManager.tenant_ids:
+            queue_name = self.service_queue_prefix + "_" + tenant_id
+            self.channel.queue_declare(
+                queue=queue_name,
+                durable=True,
+                arguments={
+                    "x-dead-letter-exchange": "",
+                    "x-dead-letter-routing-key": self.service_dlq_name,
+                    "x-expires": self.queue_expiration_time,  # TODO: check if necessary
+                    "x-max-priority": 2,
+                },
+            )
+            self.channel.queue_bind(queue_name, self.service_request_exchange_name)
 
     def start_consuming(self):
-        self.channel.queue_declare(queue=self.service_queue_prefix + "default")
+        for tenant_id in ServiceQueueManager.tenant_ids:
+            queue_name = self.service_queue_prefix + "_" + tenant_id
+            message_callback = self._make_on_message_callback(message_processor=MessageProcessor, tenant_id=tenant_id)
+            self.channel.basic_consume(
+                queue=queue_name,
+                on_message_callback=message_callback,
+            )
+            logger.info(f"Starting to consume messages for queue {queue_name}...")
 
-        self.channel.basic_consume(
-            queue=self.service_queue_prefix + "default",
-            on_message_callback=self.react_to_service_request,
-            auto_ack=True,
-        )
-
-        logger.info("Starting to consume messages...")
         self.channel.start_consuming()
+        self.connection.ioloop.start()
+
+    def publish_message_to_input_queue(
+        self, tenant_id: str, message: Union[str, bytes, dict], properties: pika.BasicProperties = None
+    ):
+        if isinstance(message, str):
+            message = message.encode("utf-8")
+        elif isinstance(message, dict):
+            message = json.dumps(message).encode("utf-8")
+
+        self.establish_connection()
+        self.channel.basic_publish(
+            exchange=self.service_request_exchange_name,
+            routing_key=tenant_id,
+            properties=properties,
+            body=message,
+        )
+        logger.info(f"Published message to queue {tenant_id}.")
+
+    def purge_queues(self):
+        self.establish_connection()
+        try:
+            for tenant_id in ServiceQueueManager.tenant_ids:
+                queue_name = self.service_queue_prefix + "_" + tenant_id
+                self.channel.queue_purge(queue_name)
+            logger.info("Queues purged.")
+        except pika.exceptions.ChannelWrongStateError:
+            pass
 
     def add_tenant_queue(self, tenant_id: str):
         queue_name = self.service_queue_prefix + "_" + tenant_id
-        self.channel.queue_declare(queue_name, durable=True)
-        self.channel.queue_bind(queue_name, self.service_request_exchange_name)
+        self.channel.queue_declare(
+            queue=queue_name,
+            durable=True,
+            arguments={
+                "x-dead-letter-exchange": "",
+                "x-dead-letter-routing-key": self.service_dlq_name,
+                "x-expires": self.queue_expiration_time,  # TODO: check if necessary
+            },
+        )
+        self.channel.queue_bind(queue=queue_name, exchange=self.service_request_exchange_name)
+        # TODO: this is likely not possible due to blocking connection
+        message_callback = self._make_on_message_callback(message_processor=MessageProcessor, tenant_id=tenant_id)
+        self.channel.basic_consume(
+            queue=queue_name,
+            on_message_callback=message_callback,
+        )
 
     def delete_tenant_queue(self, tenant_id: str):
         queue_name = self.service_queue_prefix + "_" + tenant_id
         self.channel.queue_unbind(queue_name, self.service_request_exchange_name)
         self.channel.queue_delete(queue_name)
 
-    def react_to_service_request(self, ch, method, properties, body):
-        logger.info("Received service request")
-        message = json.loads(body)
-        logger.info(f"Service Request: {message}")
-        ch.basic_ack(delivery_tag=method.delivery_tag)
-
+    def _make_on_message_callback(self, message_processor: MessageProcessor, tenant_id: str):
+        def process_message_body_and_await_result(unpacked_message_body):
+            # Processing the message in a separate thread is necessary for the main thread pika client to be able to
+            # process data events (e.g. heartbeats) while the message is being processed.
+            with concurrent.futures.ThreadPoolExecutor(max_workers=1) as thread_pool_executor:
+                logger.info("Processing payload in separate thread.")
+                future = thread_pool_executor.submit(message_processor, unpacked_message_body)
+
+                # TODO: This block is probably not necessary, but kept since the implications of removing it are
+                # unclear. Remove it in a future iteration where fewer changes are being made to the code base.
+                # while future.running():
+                #     logger.debug("Waiting for payload processing to finish...")
+                #     self.connection.sleep(self.connection_sleep)
+
+                return future.result()
+
+        def on_message_callback(channel, method, properties, body):
+            logger.info(f"Received message from queue with delivery_tag {method.delivery_tag}.")
+
+            if method.redelivered:
+                logger.warning(f"Declining message with {method.delivery_tag=} due to it being redelivered.")
+                channel.basic_nack(method.delivery_tag, requeue=False)
+                return
+
+            if body.decode("utf-8") == "STOP":
+                logger.info("Received stop signal, stopping consuming...")
+                channel.basic_ack(delivery_tag=method.delivery_tag)
+                self.stop_consuming()
+                return
+
+            try:
+                filtered_message_headers = (
+                    {k: v for k, v in properties.headers.items() if k.lower().startswith("x-")}
+                    if properties.headers
+                    else {}
+                )
+                logger.debug(f"Processing message with {filtered_message_headers=}.")
+                result: dict = (
+                    process_message_body_and_await_result({**json.loads(body), **filtered_message_headers}) or {}
+                )
+
+                channel.basic_publish(
+                    exchange=self.service_request_exchange_name,
+                    routing_key=tenant_id,
+                    body=json.dumps(result).encode(),
+                    properties=pika.BasicProperties(headers=filtered_message_headers),
+                )
+                logger.info(f"Published result to queue {tenant_id}.")
+
+                channel.basic_ack(delivery_tag=method.delivery_tag)
+                logger.debug(f"Message with {method.delivery_tag=} acknowledged.")
+            except FileNotFoundError as e:
+                logger.warning(f"{e}, declining message with {method.delivery_tag=}.")
+                channel.basic_nack(method.delivery_tag, requeue=False)
+            except Exception:
+                logger.warning(f"Failed to process message with {method.delivery_tag=}, declining...", exc_info=True)
+                channel.basic_nack(method.delivery_tag, requeue=False)
+                raise
+
+        return on_message_callback
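
Usage sketch (not part of the diff): a minimal illustration of how the two managers are intended to be wired together after this change. The settings file name, the run order, and the fact that consuming is started directly are assumptions for illustration only; connection setup is asynchronous (SelectConnection) in this change, so the exact sequencing in the real service may differ.

# Hypothetical wiring example; everything except the two manager classes is an assumption.
from dynaconf import Dynaconf

from pyinfra.queue.multiple_tenants import ServiceQueueManager, TenantQueueManager

settings = Dynaconf(settings_files=["settings.toml"])  # assumed settings source

# Fetches the initial tenant list and owns the tenant_created/tenant_deleted event queues.
tenant_manager = TenantQueueManager(settings)

# Registers add_tenant_queue/delete_tenant_queue as handlers for those events via its
# constructor, then consumes the per-tenant service request queues.
service_manager = ServiceQueueManager(settings, tenant_manager=tenant_manager)

# Assumes an open channel; in practice establish_connection() must have completed first.
service_manager.start_consuming()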