feat: wip for multiple tenants

Jonathan Kössler 2024-07-01 18:15:04 +02:00
parent 6fabe1ae8c
commit 7624208188

@@ -1,15 +1,20 @@
import atexit
import concurrent.futures
import pika
import os
import json
import logging
import signal
import sys
import requests
import time
import pika.exceptions
from threading import Thread
from dynaconf import Dynaconf
from typing import Callable, Union
from kn_utils.logging import logger
from pika.adapters.select_connection import SelectConnection
from pika.channel import Channel
from retry import retry
@@ -20,20 +25,22 @@ from pyinfra.config.validators import queue_manager_validators
pika_logger = logging.getLogger("pika")
pika_logger.setLevel(logging.WARNING) # disables non-informative pika log clutter
MessageProcessor = Callable[[dict], dict]
class BaseQueueManager:
tenant_ids = []  # shared tenant registry; always read/written via BaseQueueManager so every subclass sees the same list
def __init__(self, settings: Dynaconf):
validate_settings(settings, queue_manager_validators)
self.connection_parameters = self.create_connection_parameters(settings)
self.connection: Union[SelectConnection, None] = None
self.channel: Union[Channel, None] = None
self.connection_sleep = settings.rabbitmq.connection_sleep
self.queue_expiration_time = settings.rabbitmq.queue_expiration_time
self.tenant_exchange_name = settings.rabbitmq.tenant_exchange_name
atexit.register(self.stop_consuming)
signal.signal(signal.SIGTERM, self._handle_stop_signal)
signal.signal(signal.SIGINT, self._handle_stop_signal)
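# Sketch of the RabbitMQ settings this constructor reads (keys taken from the code
# above; the values are illustrative assumptions, the authoritative schema lives in
# queue_manager_validators):
#
#   [rabbitmq]
#   connection_sleep = 5                   # seconds between reconnect attempts
#   queue_expiration_time = 300000         # ms, passed as x-expires on tenant queues
#   tenant_exchange_name = "tenant-events"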
@@ -49,23 +56,54 @@ class BaseQueueManager:
}
return pika.ConnectionParameters(**pika_connection_params)
# @retry(tries=3, delay=5, jitter=(1, 3), logger=logger)
# def establish_connection(self):
# if self.connection and self.connection.is_open:
# logger.debug("Connection to RabbitMQ already established.")
# return
# logger.info("Establishing connection to RabbitMQ...")
# self.connection = pika.BlockingConnection(parameters=self.connection_parameters)
# logger.debug("Opening channel...")
# self.channel = self.connection.channel()
# self.channel.basic_qos(prefetch_count=1)
# self.initialize_queues()
# logger.info("Connection to RabbitMQ established, channel open.")
def establish_connection(self):
if self.connection and self.connection.is_open:
logger.debug("Connection to RabbitMQ already established.")
return
logger.info("Establishing connection to RabbitMQ...")
self.connection = SelectConnection(parameters=self.connection_parameters,
on_open_callback=self.on_connection_open,
on_open_error_callback=self.on_connection_open_error,
on_close_callback=self.on_connection_close)
def on_connection_open(self, unused_connection):
logger.debug("Connection opened")
self.connection.channel(on_open_callback=self.on_channel_open)
logger.debug("Opening channel...")
self.channel = self.connection.channel()
def on_connection_open_error(self, unused_connection, err):
logger.error(f"Connection open failed, reopening in {self.connection_sleep} seconds: {err}")
self.connection.ioloop.call_later(self.connection_sleep, self.establish_connection)
def on_connection_close(self, unused_connection, reason):
logger.warning(f"Connection closed, reopening in {self.connection_sleep} seconds: {reason}")
self.connection.ioloop.call_later(self.connection_sleep, self.establish_connection)
def on_channel_open(self, channel):
logger.debug("Channel opened")
self.channel = channel
self.channel.basic_qos(prefetch_count=1)
self.initialize_queues()
logger.info("Connection to RabbitMQ established, channel open.")
logger.info("Starting to consume messages...")
Thread(target=self.channel.start_consuming).start()
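# Hypothetical lifecycle sketch for the callback chain above (only the names from
# this diff are real): SelectConnection runs everything from its ioloop, so a
# caller has to start it explicitly:
#
#   manager = TenantQueueManager(settings)
#   manager.establish_connection()       # schedules on_connection_open
#   manager.connection.ioloop.start()    # blocks; all open/close callbacks fire here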
def is_ready(self):
self.establish_connection()
# the channel is opened asynchronously via on_channel_open, so it can still be None here
return self.channel is not None and self.channel.is_open
def initialize_queues(self):
raise NotImplementedError("Subclasses should implement this method")
@@ -94,8 +132,11 @@ class TenantQueueManager(BaseQueueManager):
self.tenant_created_queue_name = self.get_tenant_created_queue_name(settings)
self.tenant_deleted_queue_name = self.get_tenant_deleted_queue_name(settings)
self.tenant_events_dlq_name = self.get_tenant_events_dlq_name(settings)
self.event_handlers = {"tenant_created": [], "tenant_deleted": []}
BaseQueueManager.tenant_ids = self.get_initial_tenant_ids(
tenant_endpoint_url=settings.storage.tenant_server.endpoint
)
def initialize_queues(self):
self.channel.exchange_declare(exchange=self.tenant_exchange_name, exchange_type="topic")
@@ -134,6 +175,21 @@ class TenantQueueManager(BaseQueueManager):
self.channel.basic_consume(queue=self.tenant_created_queue_name, on_message_callback=self.on_tenant_created)
self.channel.basic_consume(queue=self.tenant_deleted_queue_name, on_message_callback=self.on_tenant_deleted)
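# Example (assumed) event payload arriving on these queues; on_tenant_created and
# on_tenant_deleted below read the "tenantId" field from the parsed message:
#   {"tenantId": "acme-corp"}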
@retry(tries=3, delay=5, jitter=(1, 3), logger=logger, exceptions=requests.exceptions.HTTPError)
def get_initial_tenant_ids(self, tenant_endpoint_url: str) -> list:
tenant_ids = []
try:
response = requests.get(tenant_endpoint_url, timeout=10)
response.raise_for_status()  # raise an HTTPError for bad responses so @retry can re-attempt
if response.headers.get("content-type", "").lower().startswith("application/json"):
tenant_ids = [tenant["tenantId"] for tenant in response.json()]
else:
logger.warning("Response is not in JSON format.")
except requests.exceptions.HTTPError:
raise  # handled by @retry
except Exception as e:
logger.warning(f"An unexpected error occurred: {e}")
return tenant_ids
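# Example (assumed) response from the tenant endpoint, matching the parsing above:
#   [{"tenantId": "acme-corp"}, {"tenantId": "globex"}]
# which would yield ["acme-corp", "globex"].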
def get_tenant_created_queue_name(self, settings: Dynaconf):
return self.get_queue_name_with_suffix(
suffix=settings.rabbitmq.tenant_created_event_queue_suffix, pod_name=settings.kubernetes.pod_name
@@ -166,9 +222,10 @@ class TenantQueueManager(BaseQueueManager):
logger.info(f"Tenant Created: {message}")
ch.basic_ack(delivery_tag=method.delivery_tag)
# TODO: test callback
tenant_id = message["tenantId"]  # body is raw bytes; read the parsed message instead
BaseQueueManager.tenant_ids.append(tenant_id)
self._trigger_event("tenant_created", tenant_id)
def on_tenant_deleted(self, ch, method, properties, body):
logger.info("Received tenant deleted event")
@@ -176,13 +233,38 @@ class TenantQueueManager(BaseQueueManager):
logger.info(f"Tenant Deleted: {message}")
ch.basic_ack(delivery_tag=method.delivery_tag)
# TODO: test callback
tenant_id = message["tenantId"]  # body is raw bytes; read the parsed message instead
BaseQueueManager.tenant_ids.remove(tenant_id)
self._trigger_event("tenant_deleted", tenant_id)
def _trigger_event(self, event_type, tenant_id):
for handler in self.event_handlers.get(event_type, []):
try:
handler(tenant_id)
except Exception as e:
logger.error(f"Error in event handler for {event_type}: {e}", exc_info=True)
def add_event_handler(self, event_type: str, handler: Callable[[str], None]):
if event_type in self.event_handlers:
self.event_handlers[event_type].append(handler)  # event_handlers maps event type -> list of handlers
else:
logger.warning(f"Unknown event type: {event_type}")
def purge_queues(self):
self.establish_connection()
try:
self.channel.queue_purge(self.tenant_created_queue_name)
self.channel.queue_purge(self.tenant_deleted_queue_name)
self.channel.queue_purge(self.tenant_events_dlq_name)
logger.info("Queues purged.")
except pika.exceptions.ChannelWrongStateError:
pass
class ServiceQueueManager(BaseQueueManager):
def __init__(self, settings: Dynaconf, tenant_manager: TenantQueueManager):
super().__init__(settings)
self.service_request_exchange_name = settings.rabbitmq.service_request_exchange_name
@@ -190,37 +272,148 @@ class ServiceQueueManager(BaseQueueManager):
self.service_queue_prefix = settings.rabbitmq.service_request_queue_prefix
self.service_dlq_name = settings.rabbitmq.service_dlq_name
tenant_manager.add_event_handler("tenant_created", self.add_tenant_queue)
tenant_manager.add_event_handler("tenant_deleted", self.delete_tenant_queue)
def initialize_queues(self):
self.channel.exchange_declare(exchange=self.service_request_exchange_name, exchange_type="direct")
self.channel.exchange_declare(exchange=self.service_response_exchange_name, exchange_type="direct")
for tenant_id in BaseQueueManager.tenant_ids:  # shared list populated by TenantQueueManager
queue_name = self.service_queue_prefix + "_" + tenant_id
self.channel.queue_declare(
queue=queue_name,
durable=True,
arguments={
"x-dead-letter-exchange": "",
"x-dead-letter-routing-key": self.service_dlq_name,
"x-expires": self.queue_expiration_time, # TODO: check if necessary
"x-max-priority": 2,
},
)
# bind with the tenant id as routing key; publishes below route by tenant_id
self.channel.queue_bind(queue=queue_name, exchange=self.service_request_exchange_name, routing_key=tenant_id)
def start_consuming(self):
self.channel.queue_declare(queue=self.service_queue_prefix + "default")
for tenant_id in BaseQueueManager.tenant_ids:
queue_name = self.service_queue_prefix + "_" + tenant_id
# TODO (WIP): MessageProcessor is only a type alias; a concrete processor callable still needs to be injected
message_callback = self._make_on_message_callback(message_processor=MessageProcessor, tenant_id=tenant_id)
self.channel.basic_consume(
queue=queue_name,
on_message_callback=message_callback,
)
logger.info(f"Starting to consume messages for queue {queue_name}...")
self.channel.basic_consume(
queue=self.service_queue_prefix + "default",
on_message_callback=self.react_to_service_request,
# no auto_ack: react_to_service_request acks manually, and auto_ack would make that basic_ack fail
)
logger.info("Starting to consume messages...")
self.connection.ioloop.start()
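# Hypothetical wiring sketch (the settings file name is assumed, the rest comes from
# this diff): build the tenant manager first so the service manager can register its
# tenant event handlers, then block in the ioloop:
#
#   settings = Dynaconf(settings_files=["settings.toml"])
#   tenant_manager = TenantQueueManager(settings)
#   service_manager = ServiceQueueManager(settings, tenant_manager)
#   service_manager.start_consuming()    # blocks in connection.ioloop.start()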
def publish_message_to_input_queue(
self, tenant_id: str, message: Union[str, bytes, dict], properties: pika.BasicProperties = None
):
if isinstance(message, str):
message = message.encode("utf-8")
elif isinstance(message, dict):
message = json.dumps(message).encode("utf-8")
self.establish_connection()
self.channel.basic_publish(
exchange=self.service_request_exchange_name,
routing_key=tenant_id,
properties=properties,
body=message,
)
logger.info(f"Published message to queue {tenant_id}.")
def purge_queues(self):
self.establish_connection()
try:
for tenant_id in BaseQueueManager.tenant_ids:
queue_name = self.service_queue_prefix + "_" + tenant_id
self.channel.queue_purge(queue_name)
logger.info("Queues purged.")
except pika.exceptions.ChannelWrongStateError:
pass
def add_tenant_queue(self, tenant_id: str):
queue_name = self.service_queue_prefix + "_" + tenant_id
self.channel.queue_declare(
queue=queue_name,
durable=True,
arguments={
"x-dead-letter-exchange": "",
"x-dead-letter-routing-key": self.service_dlq_name,
"x-expires": self.queue_expiration_time, # TODO: check if necessary
},
)
self.channel.queue_bind(queue=queue_name, exchange=self.service_request_exchange_name, routing_key=tenant_id)
# TODO: verify that consumers can be added at runtime now that SelectConnection replaces the blocking connection
# TODO (WIP): MessageProcessor is only a type alias; a concrete processor callable still needs to be injected
message_callback = self._make_on_message_callback(message_processor=MessageProcessor, tenant_id=tenant_id)
self.channel.basic_consume(
queue=queue_name,
on_message_callback=message_callback,
)
def delete_tenant_queue(self, tenant_id: str):
queue_name = self.service_queue_prefix + "_" + tenant_id
self.channel.queue_unbind(queue_name, self.service_request_exchange_name)
self.channel.queue_delete(queue_name)
def react_to_service_request(self, ch, method, properties, body):
logger.info("Received service request")
message = json.loads(body)
logger.info(f"Service Request: {message}")
ch.basic_ack(delivery_tag=method.delivery_tag)
def _make_on_message_callback(self, message_processor: MessageProcessor, tenant_id: str):
def process_message_body_and_await_result(unpacked_message_body):
# Processing the message in a separate thread is necessary for the main thread pika client to be able to
# process data events (e.g. heartbeats) while the message is being processed.
with concurrent.futures.ThreadPoolExecutor(max_workers=1) as thread_pool_executor:
logger.info("Processing payload in separate thread.")
future = thread_pool_executor.submit(message_processor, unpacked_message_body)
# TODO: This block is probably not necessary, but kept since the implications of removing it are
# unclear. Remove it in a future iteration where less changes are being made to the code base.
# while future.running():
# logger.debug("Waiting for payload processing to finish...")
# self.connection.sleep(self.connection_sleep)
return future.result()
def on_message_callback(channel, method, properties, body):
logger.info(f"Received message from queue with delivery_tag {method.delivery_tag}.")
if method.redelivered:
logger.warning(f"Declining message with {method.delivery_tag=} due to it being redelivered.")
channel.basic_nack(method.delivery_tag, requeue=False)
return
if body.decode("utf-8") == "STOP":
logger.info("Received stop signal, stopping consuming...")
channel.basic_ack(delivery_tag=method.delivery_tag)
self.stop_consuming()
return
try:
filtered_message_headers = (
{k: v for k, v in properties.headers.items() if k.lower().startswith("x-")}
if properties.headers
else {}
)
logger.debug(f"Processing message with {filtered_message_headers=}.")
result: dict = (
process_message_body_and_await_result({**json.loads(body), **filtered_message_headers}) or {}
)
channel.basic_publish(
# results go to the response exchange; publishing back to the request exchange
# would loop them straight into this same tenant queue
exchange=self.service_response_exchange_name,
routing_key=tenant_id,
body=json.dumps(result).encode(),
properties=pika.BasicProperties(headers=filtered_message_headers),
)
logger.info(f"Published result with routing key {tenant_id}.")
channel.basic_ack(delivery_tag=method.delivery_tag)
logger.debug(f"Message with {method.delivery_tag=} acknowledged.")
except FileNotFoundError as e:
logger.warning(f"{e}, declining message with {method.delivery_tag=}.")
channel.basic_nack(method.delivery_tag, requeue=False)
except Exception:
logger.warning(f"Failed to process message with {method.delivery_tag=}, declining...", exc_info=True)
channel.basic_nack(method.delivery_tag, requeue=False)
raise
return on_message_callback
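# Hypothetical processor satisfying the MessageProcessor alias (Callable[[dict], dict]);
# the function name and payload shape are illustrative only:
#
#   def echo_processor(payload: dict) -> dict:
#       return {"received": payload}
#
#   callback = service_manager._make_on_message_callback(
#       message_processor=echo_processor, tenant_id="acme-corp"
#   )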