From 6fabe1ae8cf28a753fe826784aa4ab1142446191 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jonathan=20K=C3=B6ssler?= Date: Fri, 28 Jun 2024 15:41:53 +0200 Subject: [PATCH] feat: wip for multiple tenants --- pyinfra/queue/multiple_tenants.py | 226 ++++++++++++++++++++++++++++++ 1 file changed, 226 insertions(+) create mode 100644 pyinfra/queue/multiple_tenants.py diff --git a/pyinfra/queue/multiple_tenants.py b/pyinfra/queue/multiple_tenants.py new file mode 100644 index 0000000..5d0b948 --- /dev/null +++ b/pyinfra/queue/multiple_tenants.py @@ -0,0 +1,226 @@ +import atexit +import pika +import os +import json +import logging +import signal +import sys +from threading import Thread +from dynaconf import Dynaconf +from typing import Callable, Union +from kn_utils.logging import logger +from pika.adapters.blocking_connection import BlockingChannel, BlockingConnection +from pika.channel import Channel +from retry import retry + +from pyinfra.config.loader import validate_settings +from pyinfra.config.validators import queue_manager_validators + + +pika_logger = logging.getLogger("pika") +pika_logger.setLevel(logging.WARNING) # disables non-informative pika log clutter + + +class BaseQueueManager: + def __init__(self, settings: Dynaconf): + validate_settings(settings, queue_manager_validators) + + self.connection_parameters = self.create_connection_parameters(settings) + self.connection: Union[BlockingConnection, None] = None + self.channel: Union[BlockingChannel, None] = None + self.connection_sleep = settings.rabbitmq.connection_sleep + self.queue_expiration_time = settings.rabbitmq.queue_expiration_time + self.tenant_exchange_name = settings.rabbitmq.tenant_exchange_name + + tenant_ids = [] + + atexit.register(self.stop_consuming) + signal.signal(signal.SIGTERM, self._handle_stop_signal) + signal.signal(signal.SIGINT, self._handle_stop_signal) + + @staticmethod + def create_connection_parameters(settings: Dynaconf): + credentials = pika.PlainCredentials(username=settings.rabbitmq.username, password=settings.rabbitmq.password) + pika_connection_params = { + "host": settings.rabbitmq.host, + "port": settings.rabbitmq.port, + "credentials": credentials, + "heartbeat": settings.rabbitmq.heartbeat, + } + return pika.ConnectionParameters(**pika_connection_params) + + @retry(tries=3, delay=5, jitter=(1, 3), logger=logger) + def establish_connection(self): + if self.connection and self.connection.is_open: + logger.debug("Connection to RabbitMQ already established.") + return + + logger.info("Establishing connection to RabbitMQ...") + self.connection = pika.BlockingConnection(parameters=self.connection_parameters) + + logger.debug("Opening channel...") + self.channel = self.connection.channel() + self.channel.basic_qos(prefetch_count=1) + self.initialize_queues() + + logger.info("Connection to RabbitMQ established, channel open.") + logger.info("Starting to consume messages...") + Thread(target=self.channel.start_consuming).start() + + def initialize_queues(self): + raise NotImplementedError("Subclasses should implement this method") + + def stop_consuming(self): + if self.channel and self.channel.is_open: + logger.info("Stopping consuming...") + self.channel.stop_consuming() + logger.info("Closing channel...") + self.channel.close() + + if self.connection and self.connection.is_open: + logger.info("Closing connection to RabbitMQ...") + self.connection.close() + + def _handle_stop_signal(self, signum, *args, **kwargs): + logger.info(f"Received signal {signum}, stopping consuming...") + self.stop_consuming() + sys.exit(0) + + +class TenantQueueManager(BaseQueueManager): + def __init__(self, settings: Dynaconf): + super().__init__(settings) + + self.tenant_created_queue_name = self.get_tenant_created_queue_name(settings) + self.tenant_deleted_queue_name = self.get_tenant_deleted_queue_name(settings) + self.tenant_events_dlq_name = self.get_tenant_events_dlq_name(settings) + + self.tenant_ids = [] + + def initialize_queues(self): + self.channel.exchange_declare(exchange=self.tenant_exchange_name, exchange_type="topic") + + self.channel.queue_declare( + queue=self.tenant_created_queue_name, + arguments={ + "x-dead-letter-exchange": "", + "x-dead-letter-routing-key": self.tenant_events_dlq_name, + "x-expires": self.queue_expiration_time, + }, + durable=True, + ) + self.channel.queue_declare( + queue=self.tenant_deleted_queue_name, + arguments={ + "x-dead-letter-exchange": "", + "x-dead-letter-routing-key": self.tenant_events_dlq_name, + "x-expires": self.queue_expiration_time, + }, + durable=True, + ) + self.channel.queue_declare( + queue=self.tenant_events_dlq_name, + arguments={"x-expires": self.queue_expiration_time}, + durable=True, + ) + + self.channel.queue_bind( + exchange=self.tenant_exchange_name, queue=self.tenant_created_queue_name, routing_key="tenant.created" + ) + self.channel.queue_bind( + exchange=self.tenant_exchange_name, queue=self.tenant_deleted_queue_name, routing_key="tenant.delete" + ) + + self.channel.basic_consume(queue=self.tenant_created_queue_name, on_message_callback=self.on_tenant_created) + self.channel.basic_consume(queue=self.tenant_deleted_queue_name, on_message_callback=self.on_tenant_deleted) + + def get_tenant_created_queue_name(self, settings: Dynaconf): + return self.get_queue_name_with_suffix( + suffix=settings.rabbitmq.tenant_created_event_queue_suffix, pod_name=settings.kubernetes.pod_name + ) + + def get_tenant_deleted_queue_name(self, settings: Dynaconf): + return self.get_queue_name_with_suffix( + suffix=settings.rabbitmq.tenant_deleted_event_queue_suffix, pod_name=settings.kubernetes.pod_name + ) + + def get_tenant_events_dlq_name(self, settings: Dynaconf): + return self.get_queue_name_with_suffix( + suffix=settings.rabbitmq.tenant_event_dlq_suffix, pod_name=settings.kubernetes.pod_name + ) + + def get_queue_name_with_suffix(self, suffix: str, pod_name: str): + if not self.use_default_queue_name() and pod_name: + return f"{pod_name}{suffix}" + return self.get_default_queue_name() + + def use_default_queue_name(self): + return False + + def get_default_queue_name(self): + raise NotImplementedError("Queue name method not implemented") + + def on_tenant_created(self, ch: Channel, method, properties, body): + logger.info("Received tenant created event") + message = json.loads(body) + logger.info(f"Tenant Created: {message}") + ch.basic_ack(delivery_tag=method.delivery_tag) + + #TODO: replace this w/ working callback + tenant_id = body["tenant_id"] + self.tenant_ids.append(tenant_id) + + def on_tenant_deleted(self, ch, method, properties, body): + logger.info("Received tenant deleted event") + message = json.loads(body) + logger.info(f"Tenant Deleted: {message}") + ch.basic_ack(delivery_tag=method.delivery_tag) + + #TODO: replace this w/ working callback + tenant_id = body["tenant_id"] + self.tenant_ids.remove(tenant_id) + + +class ServiceQueueManager(BaseQueueManager): + def __init__(self, settings: Dynaconf): + super().__init__(settings) + + self.service_request_exchange_name = settings.rabbitmq.service_request_exchange_name + self.service_response_exchange_name = settings.rabbitmq.service_response_exchange_name + self.service_queue_prefix = settings.rabbitmq.service_request_queue_prefix + self.service_dlq_name = settings.rabbitmq.service_dlq_name + + def initialize_queues(self): + self.channel.exchange_declare(exchange=self.service_request_exchange_name, exchange_type="topic") + queue_name = self.service_queue_prefix + "default" + self.channel.queue_declare(queue=queue_name, arguments={"x-max-priority": 2}) + self.channel.queue_bind(exchange=self.service_request_exchange_name, queue=queue_name) + + def start_consuming(self): + self.channel.queue_declare(queue=self.service_queue_prefix + "default") + + self.channel.basic_consume( + queue=self.service_queue_prefix + "default", + on_message_callback=self.react_to_service_request, + auto_ack=True, + ) + + logger.info("Starting to consume messages...") + self.channel.start_consuming() + + def add_tenant_queue(self, tenant_id: str): + queue_name = self.service_queue_prefix + "_" + tenant_id + self.channel.queue_declare(queue_name, durable=True) + self.channel.queue_bind(queue_name, self.service_request_exchange_name) + + def delete_tenant_queue(self, tenant_id: str): + queue_name = self.service_queue_prefix + "_" + tenant_id + self.channel.queue_unbind(queue_name, self.service_request_exchange_name) + self.channel.queue_delete(queue_name) + + def react_to_service_request(self, ch, method, properties, body): + logger.info("Received service request") + message = json.loads(body) + logger.info(f"Service Request: {message}") + ch.basic_ack(delivery_tag=method.delivery_tag) +