diff --git a/pyinfra/examples.py b/pyinfra/examples.py
index 6b62f98..a3fb9be 100644
--- a/pyinfra/examples.py
+++ b/pyinfra/examples.py
@@ -4,7 +4,8 @@
 from kn_utils.logging import logger
 from pyinfra.config.loader import get_pyinfra_validators, validate_settings
 from pyinfra.queue.callback import Callback
-from pyinfra.queue.manager import QueueManager
+# from pyinfra.queue.manager import QueueManager
+from pyinfra.queue.sequential_tenants import QueueManager
 from pyinfra.utils.opentelemetry import instrument_pika, setup_trace, instrument_app
 from pyinfra.webserver.prometheus import (
     add_prometheus_endpoint,
@@ -52,4 +53,5 @@ def start_standard_queue_consumer(
     webserver_thread = create_webserver_thread_from_settings(app, settings)
     webserver_thread.start()
 
-    queue_manager.start_consuming(callback)
\ No newline at end of file
+    # queue_manager.start_consuming(callback)
+    queue_manager.start_sequential_consume(callback)
\ No newline at end of file
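The example entry point now drives the new sequential consumer. A minimal caller would look like the sketch below; the config path and the echo processor are illustrative assumptions, not part of this diff:

```python
# Hypothetical service entry point exercising the sequential consumer.
from pyinfra.config.loader import load_settings, local_pyinfra_root_path
from pyinfra.queue.sequential_tenants import QueueManager

settings = load_settings(local_pyinfra_root_path / "config/")


def process_message(body: dict) -> dict:
    # The consumer passes in the decoded JSON body merged with any "x-" headers;
    # the returned dict is published to the response exchange for the tenant.
    return {"echo": body}  # placeholder processor


QueueManager(settings).start_sequential_consume(process_message)
```

Note that `start_sequential_consume` blocks until a STOP message or a termination signal arrives, so it should be the last call on the main thread.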
+ + """ + logger.debug('Creating a new channel') self.connection.channel(on_open_callback=self.on_channel_open) def on_connection_open_error(self, unused_connection, err): @@ -98,9 +112,33 @@ class BaseQueueManager: def on_channel_open(self, channel): logger.debug("Channel opened") self.channel = channel + # self.add_on_channel_close_callback() self.channel.basic_qos(prefetch_count=1) self.initialize_queues() + + # def add_on_channel_close_callback(self): + # """This method tells pika to call the on_channel_closed method if + # RabbitMQ unexpectedly closes the channel. + + # """ + # logger.debug('Adding channel close callback') + # self.channel.add_on_close_callback(self.on_channel_closed) + + # def on_channel_closed(self, channel, reason): + # """Invoked by pika when RabbitMQ unexpectedly closes the channel. + # Channels are usually closed if you attempt to do something that + # violates the protocol, such as re-declare an exchange or queue with + # different parameters. In this case, we'll close the connection + # to shutdown the object. + + # :param pika.channel.Channel: The closed channel + # :param Exception reason: why the channel was closed + + # """ + # logger.warning('Channel %i was closed: %s', channel, reason) + # self.close_connection() + def is_ready(self): self.establish_connection() return self.channel.is_open @@ -134,7 +172,7 @@ class TenantQueueManager(BaseQueueManager): self.tenant_events_dlq_name = self.get_tenant_events_dlq_name(settings) self.event_handlers = {"tenant_created": [], "tenant_deleted": []} - TenantQueueManager.tenant_ids = self.get_initial_tenant_ids( + self.get_initial_tenant_ids( tenant_endpoint_url=settings.storage.tenant_server.endpoint ) @@ -146,7 +184,6 @@ class TenantQueueManager(BaseQueueManager): arguments={ "x-dead-letter-exchange": "", "x-dead-letter-routing-key": self.tenant_events_dlq_name, - "x-expires": self.queue_expiration_time, }, durable=True, ) @@ -155,13 +192,11 @@ class TenantQueueManager(BaseQueueManager): arguments={ "x-dead-letter-exchange": "", "x-dead-letter-routing-key": self.tenant_events_dlq_name, - "x-expires": self.queue_expiration_time, }, durable=True, ) self.channel.queue_declare( queue=self.tenant_events_dlq_name, - arguments={"x-expires": self.queue_expiration_time}, durable=True, ) @@ -175,6 +210,11 @@ class TenantQueueManager(BaseQueueManager): self.channel.basic_consume(queue=self.tenant_created_queue_name, on_message_callback=self.on_tenant_created) self.channel.basic_consume(queue=self.tenant_deleted_queue_name, on_message_callback=self.on_tenant_deleted) + def start(self): + self.connection = self.establish_connection() + if self.connection is not None: + self.connection.ioloop.start() + @retry(tries=3, delay=5, jitter=(1, 3), logger=logger, exceptions=requests.exceptions.HTTPError) def get_initial_tenant_ids(self, tenant_endpoint_url: str) -> list: try: @@ -182,13 +222,13 @@ class TenantQueueManager(BaseQueueManager): response.raise_for_status() # Raise an HTTPError for bad responses if response.headers["content-type"].lower() == "application/json": - tenant_ids = [tenant["tenantId"] for tenant in response.json()] + tenants = [tenant["tenantId"] for tenant in response.json()] else: logger.warning("Response is not in JSON format.") except Exception as e: logger.warning("An unexpected error occurred:", e) - return tenant_ids + self.tenant_ids.extend(tenants) def get_tenant_created_queue_name(self, settings: Dynaconf): return self.get_queue_name_with_suffix( @@ -224,10 +264,10 @@ class 
@@ -224,10 +264,10 @@ class TenantQueueManager(BaseQueueManager):
 
         # TODO: test callback
         tenant_id = body["tenantId"]
-        TenantQueueManager.tenant_ids.append(tenant_id)
+        self.tenant_ids.append(tenant_id)
         self._trigger_event("tenant_created", tenant_id)
 
-    def on_tenant_deleted(self, ch, method, properties, body):
+    def on_tenant_deleted(self, ch: Channel, method, properties, body):
         logger.info("Received tenant deleted event")
         message = json.loads(body)
         logger.info(f"Tenant Deleted: {message}")
@@ -235,7 +275,7 @@ class TenantQueueManager(BaseQueueManager):
 
         # TODO: test callback
         tenant_id = body["tenantId"]
-        TenantQueueManager.tenant_ids.remove(tenant_id)
+        self.tenant_ids.remove(tenant_id)
         self._trigger_event("tenant_deleted", tenant_id)
 
     def _trigger_event(self, event_type, tenant_id):
@@ -279,7 +319,7 @@ class ServiceQueueManager(BaseQueueManager):
         self.channel.exchange_declare(exchange=self.service_request_exchange_name, exchange_type="direct")
         self.channel.exchange_declare(exchange=self.service_response_exchange_name, exchange_type="direct")
 
-        for tenant_id in ServiceQueueManager.tenant_ids:
+        for tenant_id in self.tenant_ids:
             queue_name = self.service_queue_prefix + "_" + tenant_id
             self.channel.queue_declare(
                 queue=queue_name,
@@ -293,18 +333,22 @@ class ServiceQueueManager(BaseQueueManager):
             )
             self.channel.queue_bind(queue_name, self.service_request_exchange_name)
 
-    def start_consuming(self):
-        for tenant_id in ServiceQueueManager.tenant_ids:
+    def start_consuming(self, message_processor: Callable):
+        self.connection = self.establish_connection()
+        for tenant_id in self.tenant_ids:
             queue_name = self.service_queue_prefix + "_" + tenant_id
-            message_callback = self._make_on_message_callback(message_processor=MessageProcessor, tenant_id=tenant_id)
+            message_callback = self._make_on_message_callback(message_processor=message_processor, tenant_id=tenant_id)
             self.channel.basic_consume(
                 queue=queue_name,
                 on_message_callback=message_callback,
             )
             logger.info(f"Starting to consume messages for queue {queue_name}...")
 
-        self.channel.start_consuming()
-        self.connection.ioloop.start()
+        # self.channel.start_consuming()
+        if self.connection is not None:
+            self.connection.ioloop.start()
+        else:
+            logger.info("Connection is None, cannot start ioloop")
 
     def publish_message_to_input_queue(
         self, tenant_id: str, message: Union[str, bytes, dict], properties: pika.BasicProperties = None
@@ -326,7 +370,7 @@ class ServiceQueueManager(BaseQueueManager):
     def purge_queues(self):
         self.establish_connection()
         try:
-            for tenant_id in ServiceQueueManager.tenant_ids:
+            for tenant_id in self.tenant_ids:
                 queue_name = self.service_queue_prefix + "_" + tenant_id
                 self.channel.queue_purge(queue_name)
             logger.info("Queues purged.")
@@ -371,7 +415,10 @@ class ServiceQueueManager(BaseQueueManager):
             # logger.debug("Waiting for payload processing to finish...")
             # self.connection.sleep(self.connection_sleep)
 
-            return future.result()
+            loop = asyncio.get_event_loop()
+            return loop.run_in_executor(None, future.result)
+
+            # return future.result()
 
         def on_message_callback(channel, method, properties, body):
             logger.info(f"Received message from queue with delivery_tag {method.delivery_tag}.")
@@ -416,4 +463,4 @@ class ServiceQueueManager(BaseQueueManager):
                 channel.basic_nack(method.delivery_tag, requeue=False)
                 raise
 
-        return on_message_callback
+        return on_message_callback
\ No newline at end of file
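One caveat in this file: `loop.run_in_executor(None, future.result)` returns an `asyncio.Future` rather than the processor's result, so the surrounding callback would end up `json.dumps`-ing a future unless something resolves it first. (Similarly, `get_initial_tenant_ids` now extends `self.tenant_ids` in place but keeps its `-> list` annotation and implicitly returns `None`, so callers should no longer assign its result.) With a `SelectConnection`, a common alternative is to keep the ioloop running and hand results back from the worker thread via `add_callback_threadsafe`; a sketch under that assumption, with illustrative names:

```python
# Sketch only: offload processing, then re-enter the pika ioloop thread to
# ack, instead of blocking the message callback on future.result().
import concurrent.futures
import functools

executor = concurrent.futures.ThreadPoolExecutor(max_workers=1)


def handle_message(connection, channel, method, body, message_processor):
    future = executor.submit(message_processor, body)

    def on_done(fut):
        # Channel operations are not thread-safe; marshal the ack back
        # onto the connection's ioloop thread.
        connection.ioloop.add_callback_threadsafe(
            functools.partial(channel.basic_ack, method.delivery_tag)
        )

    future.add_done_callback(on_done)
```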
diff --git a/pyinfra/queue/sequential_tenants.py b/pyinfra/queue/sequential_tenants.py
new file mode 100644
index 0000000..e0514a2
--- /dev/null
+++ b/pyinfra/queue/sequential_tenants.py
@@ -0,0 +1,342 @@
+import atexit
+import concurrent.futures
+import json
+import logging
+import requests
+import signal
+import sys
+import time
+from typing import Callable, Union
+
+import pika
+import pika.exceptions
+from dynaconf import Dynaconf
+from kn_utils.logging import logger
+from pika.adapters.blocking_connection import BlockingChannel, BlockingConnection
+from retry import retry
+
+from pyinfra.config.loader import validate_settings
+from pyinfra.config.validators import queue_manager_validators
+
+logger.set_level("DEBUG")
+pika_logger = logging.getLogger("pika")
+pika_logger.setLevel(logging.WARNING)  # disables non-informative pika log clutter
+
+MessageProcessor = Callable[[dict], dict]
+
+
+class QueueManager:
+    def __init__(self, settings: Dynaconf):
+        validate_settings(settings, queue_manager_validators)
+
+        self.tenant_created_queue_name = self.get_tenant_created_queue_name(settings)
+        self.tenant_deleted_queue_name = self.get_tenant_deleted_queue_name(settings)
+        self.tenant_events_dlq_name = self.get_tenant_events_dlq_name(settings)
+
+        self.connection_sleep = settings.rabbitmq.connection_sleep
+        self.queue_expiration_time = settings.rabbitmq.queue_expiration_time
+
+        self.tenant_exchange_name = settings.rabbitmq.tenant_exchange_name
+        self.service_request_exchange_name = settings.rabbitmq.service_request_exchange_name
+        self.service_response_exchange_name = settings.rabbitmq.service_response_exchange_name
+
+        self.service_queue_prefix = settings.rabbitmq.service_request_queue_prefix
+        self.service_dlq_name = settings.rabbitmq.service_dlq_name
+
+        self.connection_parameters = self.create_connection_parameters(settings)
+
+        self.connection: Union[BlockingConnection, None] = None
+        self.channel: Union[BlockingChannel, None] = None
+
+        self.tenant_ids = self.get_initial_tenant_ids(tenant_endpoint_url=settings.storage.tenant_server.endpoint)
+
+        self._consuming = False
+
+        atexit.register(self.stop_consuming)
+        signal.signal(signal.SIGTERM, self._handle_stop_signal)
+        signal.signal(signal.SIGINT, self._handle_stop_signal)
+
+    @staticmethod
+    def create_connection_parameters(settings: Dynaconf):
+        credentials = pika.PlainCredentials(username=settings.rabbitmq.username, password=settings.rabbitmq.password)
+        pika_connection_params = {
+            "host": settings.rabbitmq.host,
+            "port": settings.rabbitmq.port,
+            "credentials": credentials,
+            "heartbeat": settings.rabbitmq.heartbeat,
+        }
+
+        return pika.ConnectionParameters(**pika_connection_params)
+
+    @retry(tries=3, delay=5, jitter=(1, 3), logger=logger, exceptions=requests.exceptions.HTTPError)
+    def get_initial_tenant_ids(self, tenant_endpoint_url: str) -> list:
+        tenants = []
+        try:
+            response = requests.get(tenant_endpoint_url, timeout=10)
+            response.raise_for_status()  # Raise an HTTPError for bad responses
+
+            if response.headers["content-type"].lower() == "application/json":
+                tenants = [tenant["tenantId"] for tenant in response.json()]
+            else:
+                logger.warning("Response is not in JSON format.")
+        except Exception as e:
+            logger.warning(f"An unexpected error occurred: {e}")
+
+        return tenants
+
+    def get_tenant_created_queue_name(self, settings: Dynaconf):
+        return self.get_queue_name_with_suffix(
+            suffix=settings.rabbitmq.tenant_created_event_queue_suffix, pod_name=settings.kubernetes.pod_name
+        )
+
+    def get_tenant_deleted_queue_name(self, settings: Dynaconf):
+        return self.get_queue_name_with_suffix(
+            suffix=settings.rabbitmq.tenant_deleted_event_queue_suffix, pod_name=settings.kubernetes.pod_name
+        )
+
+    def get_tenant_events_dlq_name(self, settings: Dynaconf):
+        return self.get_queue_name_with_suffix(
+            suffix=settings.rabbitmq.tenant_event_dlq_suffix, pod_name=settings.kubernetes.pod_name
+        )
+
+    def get_queue_name_with_suffix(self, suffix: str, pod_name: str):
+        if not self.use_default_queue_name() and pod_name:
+            return f"{pod_name}{suffix}"
+        return self.get_default_queue_name()
+
+    def use_default_queue_name(self):
+        return False
+
+    def get_default_queue_name(self):
+        raise NotImplementedError("Queue name method not implemented")
+
+    @retry(tries=3, delay=5, jitter=(1, 3), logger=logger)
+    def establish_connection(self):
+        # TODO: set sensible retry parameters
+        if self.connection and self.connection.is_open:
+            logger.debug("Connection to RabbitMQ already established.")
+            return
+
+        logger.info("Establishing connection to RabbitMQ...")
+        self.connection = pika.BlockingConnection(parameters=self.connection_parameters)
+
+        logger.debug("Opening channel...")
+        self.channel = self.connection.channel()
+        self.channel.basic_qos(prefetch_count=1)
+
+        args = {
+            "x-dead-letter-exchange": "",
+            "x-dead-letter-routing-key": self.tenant_events_dlq_name,
+        }
+
+        ### Declare exchanges for tenants and responses
+        self.channel.exchange_declare(exchange=self.tenant_exchange_name, exchange_type="topic")
+        self.channel.exchange_declare(exchange=self.service_request_exchange_name, exchange_type="direct")
+        self.channel.exchange_declare(exchange=self.service_response_exchange_name, exchange_type="direct")
+
+        self.channel.queue_declare(self.tenant_created_queue_name, arguments=args, auto_delete=False, durable=True)
+        self.channel.queue_declare(self.tenant_deleted_queue_name, arguments=args, auto_delete=False, durable=True)
+
+        self.channel.queue_bind(
+            exchange=self.tenant_exchange_name, queue=self.tenant_created_queue_name, routing_key="tenant.created"
+        )
+        self.channel.queue_bind(
+            exchange=self.tenant_exchange_name, queue=self.tenant_deleted_queue_name, routing_key="tenant.delete"
+        )
+
+        for tenant_id in self.tenant_ids:
+            queue_name = self.service_queue_prefix + "_" + tenant_id
+            self.channel.queue_declare(
+                queue=queue_name,
+                durable=True,
+                arguments={
+                    "x-dead-letter-exchange": "",
+                    "x-dead-letter-routing-key": self.service_dlq_name,
+                    "x-expires": self.queue_expiration_time,  # TODO: check if necessary
+                    "x-max-priority": 2,
+                },
+            )
+            self.channel.queue_bind(
+                queue=queue_name, exchange=self.service_request_exchange_name, routing_key=tenant_id
+            )
+
+        logger.info("Connection to RabbitMQ established, channel open.")
+
+    def is_ready(self):
+        self.establish_connection()
+        return self.channel.is_open
+
+    @retry(exceptions=pika.exceptions.AMQPConnectionError, tries=3, delay=5, jitter=(1, 3), logger=logger)
+    def start_sequential_consume(self, message_processor: Callable):
+        self.establish_connection()
+        self._consuming = True
+
+        try:
+            while self._consuming:
+                for tenant_id in self.tenant_ids:
+                    queue_name = self.service_queue_prefix + "_" + tenant_id
+                    method_frame, properties, body = self.channel.basic_get(queue_name)
+                    if method_frame:
+                        on_message_callback = self._make_on_message_callback(message_processor, tenant_id)
+                        on_message_callback(self.channel, method_frame, properties, body)
+                    else:
+                        logger.debug("No message returned")
+                        time.sleep(self.connection_sleep)
+
+                ### Handle tenant events
+                self.check_tenant_created_queue()
+                self.check_tenant_deleted_queue()
+
+        except KeyboardInterrupt:
+            logger.info("Exiting...")
+        finally:
+            self.stop_consuming()
+
+    def check_tenant_created_queue(self):
+        while True:
+            method_frame, properties, body = self.channel.basic_get(self.tenant_created_queue_name)
+            if method_frame:
+                self.channel.basic_ack(delivery_tag=method_frame.delivery_tag)
+                message = json.loads(body)
+                tenant_id = message["tenantId"]
+                self.on_tenant_created(tenant_id)
+            else:
+                logger.debug("No more tenant created events.")
+                break
+
+    def check_tenant_deleted_queue(self):
+        while True:
+            method_frame, properties, body = self.channel.basic_get(self.tenant_deleted_queue_name)
+            if method_frame:
+                self.channel.basic_ack(delivery_tag=method_frame.delivery_tag)
+                message = json.loads(body)
+                tenant_id = message["tenantId"]
+                self.on_tenant_deleted(tenant_id)
+            else:
+                logger.debug("No more tenant deleted events.")
+                break
+
+    def on_tenant_created(self, tenant_id: str):
+        queue_name = self.service_queue_prefix + "_" + tenant_id
+        self.channel.queue_declare(
+            queue=queue_name,
+            durable=True,
+            arguments={
+                "x-dead-letter-exchange": "",
+                "x-dead-letter-routing-key": self.service_dlq_name,
+                "x-expires": self.queue_expiration_time,  # TODO: check if necessary
+            },
+        )
+        self.channel.queue_bind(queue=queue_name, exchange=self.service_request_exchange_name, routing_key=tenant_id)
+        self.tenant_ids.append(tenant_id)
+
+    def on_tenant_deleted(self, tenant_id: str):
+        queue_name = self.service_queue_prefix + "_" + tenant_id
+        self.channel.queue_unbind(queue=queue_name, exchange=self.service_request_exchange_name, routing_key=tenant_id)
+        self.channel.queue_delete(queue_name)
+        self.tenant_ids.remove(tenant_id)
+
+    def stop_consuming(self):
+        self._consuming = False
+        if self.channel and self.channel.is_open:
+            logger.info("Stopping consuming...")
+            self.channel.stop_consuming()
+            logger.info("Closing channel...")
+            self.channel.close()
+
+        if self.connection and self.connection.is_open:
+            logger.info("Closing connection to RabbitMQ...")
+            self.connection.close()
+
+    def publish_message_to_input_queue(
+        self, tenant_id: str, message: Union[str, bytes, dict], properties: pika.BasicProperties = None
+    ):
+        if isinstance(message, str):
+            message = message.encode("utf-8")
+        elif isinstance(message, dict):
+            message = json.dumps(message).encode("utf-8")
+
+        self.establish_connection()
+        self.channel.basic_publish(
+            exchange=self.service_request_exchange_name,
+            routing_key=tenant_id,
+            properties=properties,
+            body=message,
+        )
+        logger.info(f"Published message to queue {tenant_id}.")
+
+    def purge_queues(self):
+        self.establish_connection()
+        try:
+            self.channel.queue_purge(self.tenant_created_queue_name)
+            self.channel.queue_purge(self.tenant_deleted_queue_name)
+            for tenant_id in self.tenant_ids:
+                queue_name = self.service_queue_prefix + "_" + tenant_id
+                self.channel.queue_purge(queue_name)
+            logger.info("Queues purged.")
+        except pika.exceptions.ChannelWrongStateError:
+            pass
+
+    def get_message_from_output_queue(self, queue: str):
+        self.establish_connection()
+        return self.channel.basic_get(queue, auto_ack=True)
+
+    def _make_on_message_callback(self, message_processor: MessageProcessor, tenant_id: str):
+        def process_message_body_and_await_result(unpacked_message_body):
+            # Processing the message in a separate thread is necessary for the main thread pika client to be able to
+            # process data events (e.g. heartbeats) while the message is being processed.
+            with concurrent.futures.ThreadPoolExecutor(max_workers=1) as thread_pool_executor:
+                logger.info("Processing payload in separate thread.")
+                future = thread_pool_executor.submit(message_processor, unpacked_message_body)
+
+                return future.result()
+
+        def on_message_callback(channel, method, properties, body):
+            logger.info(f"Received message from queue with delivery_tag {method.delivery_tag}.")
+
+            if method.redelivered:
+                logger.warning(f"Declining message with {method.delivery_tag=} due to it being redelivered.")
+                channel.basic_nack(method.delivery_tag, requeue=False)
+                return
+
+            if body.decode("utf-8") == "STOP":
+                logger.info("Received stop signal, stopping consuming...")
+                channel.basic_ack(delivery_tag=method.delivery_tag)
+                self.stop_consuming()
+                return
+
+            try:
+                filtered_message_headers = (
+                    {k: v for k, v in properties.headers.items() if k.lower().startswith("x-")}
+                    if properties.headers
+                    else {}
+                )
+                logger.debug(f"Processing message with {filtered_message_headers=}.")
+                result: dict = (
+                    process_message_body_and_await_result({**json.loads(body), **filtered_message_headers}) or {}
+                )
+
+                channel.basic_publish(
+                    exchange=self.service_response_exchange_name,
+                    routing_key=tenant_id,
+                    body=json.dumps(result).encode(),
+                    properties=pika.BasicProperties(headers=filtered_message_headers),
+                )
+                logger.info(f"Published result to queue {tenant_id}.")
+
+                channel.basic_ack(delivery_tag=method.delivery_tag)
+                logger.debug(f"Message with {method.delivery_tag=} acknowledged.")
+            except FileNotFoundError as e:
+                logger.warning(f"{e}, declining message with {method.delivery_tag=}.")
+                channel.basic_nack(method.delivery_tag, requeue=False)
+            except Exception:
+                logger.warning(f"Failed to process message with {method.delivery_tag=}, declining...", exc_info=True)
+                channel.basic_nack(method.delivery_tag, requeue=False)
+                raise
+
+        return on_message_callback
+
+    def _handle_stop_signal(self, signum, *args, **kwargs):
+        logger.info(f"Received signal {signum}, stopping consuming...")
+        self.stop_consuming()
+        sys.exit(0)
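Two details worth double-checking in this new class: `future.result()` still blocks the calling thread, so the ThreadPoolExecutor does not by itself keep the blocking connection servicing heartbeats during long-running processors; and `on_tenant_created` declares tenant queues without the `"x-max-priority": 2` argument that `establish_connection` uses, which RabbitMQ would reject with PRECONDITION_FAILED if the same queue is later re-declared with different arguments. A round-trip usage sketch follows; the tenant id, payload field, and response queue name are assumptions for illustration:

```python
# Hypothetical round trip against the sequential manager.
from pyinfra.config.loader import load_settings, local_pyinfra_root_path
from pyinfra.queue.sequential_tenants import QueueManager

settings = load_settings(local_pyinfra_root_path / "config/")
qm = QueueManager(settings)

# Publish a request for one tenant; dicts are JSON-encoded automatically.
qm.publish_message_to_input_queue(tenant_id="tenant-a", message={"documentId": "doc-1"})

# Later, after a consumer has replied, poll the response queue (name assumed).
method, properties, body = qm.get_message_from_output_queue(queue="tenant-a_responses")
if method is not None:
    print(body)
```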
diff --git a/scripts/send_request.py b/scripts/send_request.py
index d1f1fda..c7d2046 100644
--- a/scripts/send_request.py
+++ b/scripts/send_request.py
@@ -5,7 +5,8 @@
 from operator import itemgetter
 from kn_utils.logging import logger
 from pyinfra.config.loader import load_settings, local_pyinfra_root_path
-from pyinfra.queue.manager import QueueManager
+# from pyinfra.queue.manager import QueueManager
+from pyinfra.queue.sequential_tenants import QueueManager
 from pyinfra.storage.storages.s3 import get_s3_storage_from_settings
 
 settings = load_settings(local_pyinfra_root_path / "config/")
@@ -41,7 +42,7 @@ def main():
 
     message = upload_json_and_make_message_body()
 
-    queue_manager.publish_message_to_input_queue(message)
+    queue_manager.publish_message_to_input_queue(tenant_id="redaction", message=message)
     logger.info(f"Put {message} on {settings.rabbitmq.input_queue}.")
 
     storage = get_s3_storage_from_settings(settings)
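For completeness, the consumer also honours an in-band shutdown message: a raw `STOP` body is acked and ends the consume loop. A control sketch, reusing the `"redaction"` tenant id from the script above:

```python
# Ask a running sequential consumer to shut down gracefully.
from pyinfra.config.loader import load_settings, local_pyinfra_root_path
from pyinfra.queue.sequential_tenants import QueueManager

settings = load_settings(local_pyinfra_root_path / "config/")
# Plain strings are published as UTF-8 bytes, which the consumer
# compares against "STOP" before normal JSON processing.
QueueManager(settings).publish_message_to_input_queue(tenant_id="redaction", message="STOP")
```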