From c81d967aeec609365d097b0364e4dcb9044ac3a5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jonathan=20K=C3=B6ssler?= Date: Wed, 3 Jul 2024 17:51:47 +0200 Subject: [PATCH] feat: wip for multiple tenants --- pyinfra/examples.py | 19 +- pyinfra/queue/sequential_tenants.py | 1 + ...ultiple_tenants.py => threaded_tenants.py} | 287 +++++++----------- 3 files changed, 129 insertions(+), 178 deletions(-) rename pyinfra/queue/{multiple_tenants.py => threaded_tenants.py} (62%) diff --git a/pyinfra/examples.py b/pyinfra/examples.py index a3fb9be..8856c5b 100644 --- a/pyinfra/examples.py +++ b/pyinfra/examples.py @@ -1,11 +1,13 @@ from dynaconf import Dynaconf from fastapi import FastAPI from kn_utils.logging import logger - +import multiprocessing +from threading import Thread from pyinfra.config.loader import get_pyinfra_validators, validate_settings from pyinfra.queue.callback import Callback # from pyinfra.queue.manager import QueueManager from pyinfra.queue.sequential_tenants import QueueManager +from pyinfra.queue.threaded_tenants import ServiceQueueManager, TenantQueueManager from pyinfra.utils.opentelemetry import instrument_pika, setup_trace, instrument_app from pyinfra.webserver.prometheus import ( add_prometheus_endpoint, @@ -35,7 +37,8 @@ def start_standard_queue_consumer( app = app or FastAPI() - queue_manager = QueueManager(settings) + tenant_manager = TenantQueueManager(settings) + service_manager = ServiceQueueManager(settings) if settings.metrics.prometheus.enabled: logger.info("Prometheus metrics enabled.") @@ -48,10 +51,18 @@ def start_standard_queue_consumer( instrument_pika() instrument_app(app) - app = add_health_check_endpoint(app, queue_manager.is_ready) + # app = add_health_check_endpoint(app, queue_manager.is_ready) + app = add_health_check_endpoint(app, service_manager.is_ready) webserver_thread = create_webserver_thread_from_settings(app, settings) webserver_thread.start() # queue_manager.start_consuming(callback) - queue_manager.start_sequential_consume(callback) \ No newline at end of file + # queue_manager.start_sequential_consume(callback) + # p1 = multiprocessing.Process(target=tenant_manager.start_consuming, daemon=True) + # p2 = multiprocessing.Process(target=service_manager.start_sequential_consume, kwargs={"callback":callback}, daemon=True) + thread = Thread(target=tenant_manager.start_consuming, daemon=True) + thread.start() + # p1.start() + # p2.start() + service_manager.start_sequential_consume(callback) \ No newline at end of file diff --git a/pyinfra/queue/sequential_tenants.py b/pyinfra/queue/sequential_tenants.py index e0514a2..b1eb70f 100644 --- a/pyinfra/queue/sequential_tenants.py +++ b/pyinfra/queue/sequential_tenants.py @@ -117,6 +117,7 @@ class QueueManager: return logger.info("Establishing connection to RabbitMQ...") + logger.info(self.__class__.__name__) self.connection = pika.BlockingConnection(parameters=self.connection_parameters) logger.debug("Opening channel...") diff --git a/pyinfra/queue/multiple_tenants.py b/pyinfra/queue/threaded_tenants.py similarity index 62% rename from pyinfra/queue/multiple_tenants.py rename to pyinfra/queue/threaded_tenants.py index 156736f..6ec9327 100644 --- a/pyinfra/queue/multiple_tenants.py +++ b/pyinfra/queue/threaded_tenants.py @@ -1,17 +1,18 @@ import atexit -import asyncio import concurrent.futures import pika +import queue import json import logging import signal import sys import requests +import time import pika.exceptions from dynaconf import Dynaconf from typing import Callable, Union +from pika.adapters.blocking_connection import BlockingChannel, BlockingConnection from kn_utils.logging import logger -from pika.adapters.select_connection import SelectConnection from pika.channel import Channel from retry import retry @@ -25,14 +26,14 @@ MessageProcessor = Callable[[dict], dict] class BaseQueueManager: - tenant_ids = [] + tenant_exchange = queue.Queue() def __init__(self, settings: Dynaconf): validate_settings(settings, queue_manager_validators) self.connection_parameters = self.create_connection_parameters(settings) - self.connection = None - self.channel = None + self.connection: Union[BlockingConnection, None] = None + self.channel: Union[BlockingChannel, None] = None self.connection_sleep = settings.rabbitmq.connection_sleep self.queue_expiration_time = settings.rabbitmq.queue_expiration_time self.tenant_exchange_name = settings.rabbitmq.tenant_exchange_name @@ -52,22 +53,6 @@ class BaseQueueManager: } return pika.ConnectionParameters(**pika_connection_params) - # @retry(tries=3, delay=5, jitter=(1, 3), logger=logger) - # def establish_connection(self): - # if self.connection and self.connection.is_open: - # logger.debug("Connection to RabbitMQ already established.") - # return - - # logger.info("Establishing connection to RabbitMQ...") - # self.connection = pika.BlockingConnection(parameters=self.connection_parameters) - - # logger.debug("Opening channel...") - # self.channel = self.connection.channel() - # self.channel.basic_qos(prefetch_count=1) - # self.initialize_queues() - - # logger.info("Connection to RabbitMQ established, channel open.") - @retry(tries=3, delay=5, jitter=(1, 3), logger=logger) def establish_connection(self): if self.connection and self.connection.is_open: @@ -75,69 +60,15 @@ class BaseQueueManager: return logger.info("Establishing connection to RabbitMQ...") - return SelectConnection(parameters=self.connection_parameters, - on_open_callback=self.on_connection_open, - on_open_error_callback=self.on_connection_open_error, - on_close_callback=self.on_connection_close) - - def close_connection(self): - # self._consuming = False - if self.connection.is_closing or self.connection.is_closed: - logger.info('Connection is closing or already closed') - else: - logger.info('Closing connection') - self.connection.close() - - def on_connection_open(self, unused_connection): - logger.debug("Connection opened") - self.open_channel() + logger.info(self.__class__.__name__) + self.connection = pika.BlockingConnection(parameters=self.connection_parameters) - def open_channel(self): - """Open a new channel with RabbitMQ by issuing the Channel.Open RPC - command. When RabbitMQ responds that the channel is open, the - on_channel_open callback will be invoked by pika. - - """ - logger.debug('Creating a new channel') - self.connection.channel(on_open_callback=self.on_channel_open) - - def on_connection_open_error(self, unused_connection, err): - logger.error(f"Connection open failed, reopening in {self.connection_sleep} seconds: {err}") - self.connection.ioloop.call_later(self.connection_sleep, self.establish_connection) - - def on_connection_close(self, unused_connection, reason): - logger.warning(f"Connection closed, reopening in {self.connection_sleep} seconds: {reason}") - self.connection.ioloop.call_later(self.connection_sleep, self.establish_connection) - - def on_channel_open(self, channel): - logger.debug("Channel opened") - self.channel = channel - # self.add_on_channel_close_callback() + logger.debug("Opening channel...") + self.channel = self.connection.channel() self.channel.basic_qos(prefetch_count=1) self.initialize_queues() - - # def add_on_channel_close_callback(self): - # """This method tells pika to call the on_channel_closed method if - # RabbitMQ unexpectedly closes the channel. - - # """ - # logger.debug('Adding channel close callback') - # self.channel.add_on_close_callback(self.on_channel_closed) - - # def on_channel_closed(self, channel, reason): - # """Invoked by pika when RabbitMQ unexpectedly closes the channel. - # Channels are usually closed if you attempt to do something that - # violates the protocol, such as re-declare an exchange or queue with - # different parameters. In this case, we'll close the connection - # to shutdown the object. - - # :param pika.channel.Channel: The closed channel - # :param Exception reason: why the channel was closed - - # """ - # logger.warning('Channel %i was closed: %s', channel, reason) - # self.close_connection() + logger.info("Connection to RabbitMQ established, channel open.") def is_ready(self): self.establish_connection() @@ -170,11 +101,6 @@ class TenantQueueManager(BaseQueueManager): self.tenant_created_queue_name = self.get_tenant_created_queue_name(settings) self.tenant_deleted_queue_name = self.get_tenant_deleted_queue_name(settings) self.tenant_events_dlq_name = self.get_tenant_events_dlq_name(settings) - self.event_handlers = {"tenant_created": [], "tenant_deleted": []} - - self.get_initial_tenant_ids( - tenant_endpoint_url=settings.storage.tenant_server.endpoint - ) def initialize_queues(self): self.channel.exchange_declare(exchange=self.tenant_exchange_name, exchange_type="topic") @@ -206,29 +132,20 @@ class TenantQueueManager(BaseQueueManager): self.channel.queue_bind( exchange=self.tenant_exchange_name, queue=self.tenant_deleted_queue_name, routing_key="tenant.delete" ) + + @retry(exceptions=pika.exceptions.AMQPConnectionError, tries=3, delay=5, jitter=(1, 3), logger=logger) + def start_consuming(self): - self.channel.basic_consume(queue=self.tenant_created_queue_name, on_message_callback=self.on_tenant_created) - self.channel.basic_consume(queue=self.tenant_deleted_queue_name, on_message_callback=self.on_tenant_deleted) - - def start(self): - self.connection = self.establish_connection() - if self.connection is not None: - self.connection.ioloop.start() - - @retry(tries=3, delay=5, jitter=(1, 3), logger=logger, exceptions=requests.exceptions.HTTPError) - def get_initial_tenant_ids(self, tenant_endpoint_url: str) -> list: try: - response = requests.get(tenant_endpoint_url, timeout=10) - response.raise_for_status() # Raise an HTTPError for bad responses - - if response.headers["content-type"].lower() == "application/json": - tenants = [tenant["tenantId"] for tenant in response.json()] - else: - logger.warning("Response is not in JSON format.") - except Exception as e: - logger.warning("An unexpected error occurred:", e) - - self.tenant_ids.extend(tenants) + self.establish_connection() + self.channel.basic_consume(queue=self.tenant_created_queue_name, on_message_callback=self.on_tenant_created) + self.channel.basic_consume(queue=self.tenant_deleted_queue_name, on_message_callback=self.on_tenant_deleted) + self.channel.start_consuming() + except Exception: + logger.error("An unexpected error occurred while consuming messages. Consuming will stop.", exc_info=True) + raise + finally: + self.stop_consuming() def get_tenant_created_queue_name(self, settings: Dynaconf): return self.get_queue_name_with_suffix( @@ -259,52 +176,31 @@ class TenantQueueManager(BaseQueueManager): def on_tenant_created(self, ch: Channel, method, properties, body): logger.info("Received tenant created event") message = json.loads(body) - logger.info(f"Tenant Created: {message}") ch.basic_ack(delivery_tag=method.delivery_tag) - # TODO: test callback - tenant_id = body["tenantId"] - self.tenant_ids.append(tenant_id) - self._trigger_event("tenant_created", tenant_id) + tenant_id = message["tenantId"] + self.tenant_exchange.put(("create", tenant_id)) def on_tenant_deleted(self, ch: Channel, method, properties, body): logger.info("Received tenant deleted event") message = json.loads(body) - logger.info(f"Tenant Deleted: {message}") ch.basic_ack(delivery_tag=method.delivery_tag) - - # TODO: test callback - tenant_id = body["tenantId"] - self.tenant_ids.remove(tenant_id) - self._trigger_event("tenant_deleted", tenant_id) - - def _trigger_event(self, event_type, tenant_id): - handler = self.event_handlers.get(event_type) - if handler: - try: - handler(tenant_id) - except Exception as e: - logger.error(f"Error in event handler for {event_type}: {e}", exc_info=True) - - def add_event_handler(self, event_type: str, handler: Callable[[str], None]): - if event_type in self.event_handlers: - self.event_handlers[event_type] = handler - else: - logger.warning(f"Unknown event type: {event_type}") + + tenant_id = message["tenantId"] + self.tenant_exchange.put(("delete", tenant_id)) def purge_queues(self): self.establish_connection() try: self.channel.queue_purge(self.tenant_created_queue_name) self.channel.queue_purge(self.tenant_deleted_queue_name) - self.channel.queue_purge(self.tenant_events_dlq_name) logger.info("Queues purged.") except pika.exceptions.ChannelWrongStateError: pass class ServiceQueueManager(BaseQueueManager): - def __init__(self, settings: Dynaconf, tenant_manager: TenantQueueManager): + def __init__(self, settings: Dynaconf): super().__init__(settings) self.service_request_exchange_name = settings.rabbitmq.service_request_exchange_name @@ -312,8 +208,9 @@ class ServiceQueueManager(BaseQueueManager): self.service_queue_prefix = settings.rabbitmq.service_request_queue_prefix self.service_dlq_name = settings.rabbitmq.service_dlq_name - tenant_manager.add_event_handler("tenant_created", self.add_tenant_queue) - tenant_manager.add_event_handler("tenant_deleted", self.delete_tenant_queue) + self.tenant_ids = self.get_initial_tenant_ids(tenant_endpoint_url=settings.storage.tenant_server.endpoint) + + self._consuming = False def initialize_queues(self): self.channel.exchange_declare(exchange=self.service_request_exchange_name, exchange_type="direct") @@ -331,24 +228,65 @@ class ServiceQueueManager(BaseQueueManager): "x-max-priority": 2, }, ) - self.channel.queue_bind(queue_name, self.service_request_exchange_name) - - def start_consuming(self, message_processor: Callable): - self.connection = self.establish_connection() - for tenant_id in self.tenant_ids: - queue_name = self.service_queue_prefix + "_" + tenant_id - message_callback = self._make_on_message_callback(message_processor=message_processor, tenant_id=tenant_id) - self.channel.basic_consume( - queue=queue_name, - on_message_callback=message_callback, + self.channel.queue_bind( + queue=queue_name, exchange=self.service_request_exchange_name, routing_key=tenant_id ) - logger.info(f"Starting to consume messages for queue {queue_name}...") - # self.channel.start_consuming() - if self.connection is not None: - self.connection.ioloop.start() - else: - logger.info("Connection is None, cannot start ioloop") + @retry(tries=3, delay=5, jitter=(1, 3), logger=logger, exceptions=requests.exceptions.HTTPError) + def get_initial_tenant_ids(self, tenant_endpoint_url: str) -> list: + try: + response = requests.get(tenant_endpoint_url, timeout=10) + response.raise_for_status() # Raise an HTTPError for bad responses + + if response.headers["content-type"].lower() == "application/json": + tenants = [tenant["tenantId"] for tenant in response.json()] + else: + logger.warning("Response is not in JSON format.") + except Exception as e: + logger.warning("An unexpected error occurred:", e) + + return tenants + + @retry(exceptions=pika.exceptions.AMQPConnectionError, tries=3, delay=5, jitter=(1, 3), logger=logger) + def start_sequential_consume(self, message_processor: Callable): + + self.establish_connection() + self._consuming = True + + try: + while self._consuming: + for tenant_id in self.tenant_ids: + queue_name = self.service_queue_prefix + "_" + tenant_id + method_frame, properties, body = self.channel.basic_get(queue_name) + if method_frame: + on_message_callback = self._make_on_message_callback(message_processor, tenant_id) + on_message_callback(self.channel, method_frame, properties, body) + else: + logger.debug("No message returned") + time.sleep(self.connection_sleep) + + ### Handle tenant events + self.check_tenant_exchange() + + except KeyboardInterrupt: + logger.info("Exiting...") + finally: + self.stop_consuming() + + def check_tenant_exchange(self): + while True: + try: + event, tenant = self.tenant_exchange.get(block=False) + if event == "create": + self.on_tenant_created(tenant) + elif event == "delete": + self.on_tenant_deleted(tenant) + else: + break + except Exception: + logger.debug("No tenant exchange events.") + break + def publish_message_to_input_queue( self, tenant_id: str, message: Union[str, bytes, dict], properties: pika.BasicProperties = None @@ -377,7 +315,7 @@ class ServiceQueueManager(BaseQueueManager): except pika.exceptions.ChannelWrongStateError: pass - def add_tenant_queue(self, tenant_id: str): + def on_tenant_created(self, tenant_id: str): queue_name = self.service_queue_prefix + "_" + tenant_id self.channel.queue_declare( queue=queue_name, @@ -388,18 +326,16 @@ class ServiceQueueManager(BaseQueueManager): "x-expires": self.queue_expiration_time, # TODO: check if necessary }, ) - self.channel.queue_bind(queue=queue_name, exchange=self.service_request_exchange_name) - # TODO: this is likely not possible due to blocking connection - message_callback = self._make_on_message_callback(message_processor=MessageProcessor, tenant_id=tenant_id) - self.channel.basic_consume( - queue=queue_name, - on_message_callback=message_callback, - ) + self.channel.queue_bind(queue=queue_name, exchange=self.service_request_exchange_name, routing_key=tenant_id) + self.tenant_ids.append(tenant_id) + logger.debug(f"Added tenant {tenant_id}.") - def delete_tenant_queue(self, tenant_id: str): + def on_tenant_deleted(self, tenant_id: str): queue_name = self.service_queue_prefix + "_" + tenant_id - self.channel.queue_unbind(queue_name, self.service_request_exchange_name) + self.channel.queue_unbind(queue=queue_name, exchange=self.service_request_exchange_name, routing_key=tenant_id) self.channel.queue_delete(queue_name) + self.tenant_ids.remove(tenant_id) + logger.debug(f"Deleted tenant {tenant_id}.") def _make_on_message_callback(self, message_processor: MessageProcessor, tenant_id: str): def process_message_body_and_await_result(unpacked_message_body): @@ -409,16 +345,7 @@ class ServiceQueueManager(BaseQueueManager): logger.info("Processing payload in separate thread.") future = thread_pool_executor.submit(message_processor, unpacked_message_body) - # TODO: This block is probably not necessary, but kept since the implications of removing it are - # unclear. Remove it in a future iteration where less changes are being made to the code base. - # while future.running(): - # logger.debug("Waiting for payload processing to finish...") - # self.connection.sleep(self.connection_sleep) - - loop = asyncio.get_event_loop() - return loop.run_in_executor(None, future.result) - - # return future.result() + return future.result() def on_message_callback(channel, method, properties, body): logger.info(f"Received message from queue with delivery_tag {method.delivery_tag}.") @@ -429,7 +356,7 @@ class ServiceQueueManager(BaseQueueManager): return if body.decode("utf-8") == "STOP": - logger.info("Received stop signal, stopping consuming...") + logger.info(f"Received stop signal, stopping consuming...") channel.basic_ack(delivery_tag=method.delivery_tag) self.stop_consuming() return @@ -446,7 +373,7 @@ class ServiceQueueManager(BaseQueueManager): ) channel.basic_publish( - exchange=self.service_request_exchange_name, + exchange=self.service_response_exchange_name, routing_key=tenant_id, body=json.dumps(result).encode(), properties=pika.BasicProperties(headers=filtered_message_headers), @@ -463,4 +390,16 @@ class ServiceQueueManager(BaseQueueManager): channel.basic_nack(method.delivery_tag, requeue=False) raise - return on_message_callback \ No newline at end of file + return on_message_callback + + def stop_consuming(self): + self._consuming = False + if self.channel and self.channel.is_open: + logger.info("Stopping consuming...") + self.channel.stop_consuming() + logger.info("Closing channel...") + self.channel.close() + + if self.connection and self.connection.is_open: + logger.info("Closing connection to RabbitMQ...") + self.connection.close() \ No newline at end of file