# pyinfra/pyinfra/queue/threaded_tenants.py
# Last modified: 2024-07-03 17:51:47 +02:00 (405 lines, 16 KiB, Python)
import atexit
import concurrent.futures
import pika
import queue
import json
import logging
import signal
import sys
import requests
import time
import pika.exceptions
from dynaconf import Dynaconf
from typing import Callable, Union
from pika.adapters.blocking_connection import BlockingChannel, BlockingConnection
from kn_utils.logging import logger
from pika.channel import Channel
from retry import retry
from pyinfra.config.loader import validate_settings
from pyinfra.config.validators import queue_manager_validators
# Raise pika's log level so its chatty INFO/DEBUG output stays out of our logs.
pika_logger = logging.getLogger("pika")
pika_logger.setLevel(logging.WARNING)  # disables non-informative pika log clutter
# Contract for message handlers: take the decoded message payload (dict),
# return a result dict to publish on the response exchange.
MessageProcessor = Callable[[dict], dict]
class BaseQueueManager:
    """Shared RabbitMQ connection/channel lifecycle for queue managers.

    Subclasses implement ``initialize_queues`` to declare their exchanges,
    queues and bindings; this base handles connecting, QoS, and clean
    shutdown on exit or termination signals.
    """

    # Class-level (i.e. shared across ALL manager instances) hand-off queue
    # for tenant lifecycle events: TenantQueueManager puts, ServiceQueueManager gets.
    tenant_exchange = queue.Queue()

    def __init__(self, settings: Dynaconf):
        validate_settings(settings, queue_manager_validators)
        self.connection_parameters = self.create_connection_parameters(settings)
        self.connection: Union[BlockingConnection, None] = None
        self.channel: Union[BlockingChannel, None] = None
        rabbitmq_settings = settings.rabbitmq
        self.connection_sleep = rabbitmq_settings.connection_sleep
        self.queue_expiration_time = rabbitmq_settings.queue_expiration_time
        self.tenant_exchange_name = rabbitmq_settings.tenant_exchange_name
        # Shut down cleanly on interpreter exit and on SIGTERM/SIGINT.
        atexit.register(self.stop_consuming)
        for stop_signal in (signal.SIGTERM, signal.SIGINT):
            signal.signal(stop_signal, self._handle_stop_signal)

    @staticmethod
    def create_connection_parameters(settings: Dynaconf):
        """Build pika ``ConnectionParameters`` from the rabbitmq settings section."""
        creds = pika.PlainCredentials(
            username=settings.rabbitmq.username, password=settings.rabbitmq.password
        )
        return pika.ConnectionParameters(
            host=settings.rabbitmq.host,
            port=settings.rabbitmq.port,
            credentials=creds,
            heartbeat=settings.rabbitmq.heartbeat,
        )

    @retry(tries=3, delay=5, jitter=(1, 3), logger=logger)
    def establish_connection(self):
        """Open a blocking connection and channel, unless one is already open."""
        if self.connection and self.connection.is_open:
            logger.debug("Connection to RabbitMQ already established.")
            return
        logger.info("Establishing connection to RabbitMQ...")
        logger.info(self.__class__.__name__)
        self.connection = pika.BlockingConnection(parameters=self.connection_parameters)
        logger.debug("Opening channel...")
        self.channel = self.connection.channel()
        # At most one unacknowledged message at a time per consumer.
        self.channel.basic_qos(prefetch_count=1)
        self.initialize_queues()
        logger.info("Connection to RabbitMQ established, channel open.")

    def is_ready(self):
        """Ensure a connection exists and report whether the channel is open."""
        self.establish_connection()
        return self.channel.is_open

    def initialize_queues(self):
        """Declare exchanges/queues/bindings; must be provided by subclasses."""
        raise NotImplementedError("Subclasses should implement this method")

    def stop_consuming(self):
        """Stop consuming and close the channel and connection if open."""
        if self.channel and self.channel.is_open:
            logger.info("Stopping consuming...")
            self.channel.stop_consuming()
            logger.info("Closing channel...")
            self.channel.close()
        if self.connection and self.connection.is_open:
            logger.info("Closing connection to RabbitMQ...")
            self.connection.close()

    def _handle_stop_signal(self, signum, *args, **kwargs):
        """SIGTERM/SIGINT handler: shut down consuming, then exit the process."""
        logger.info(f"Received signal {signum}, stopping consuming...")
        self.stop_consuming()
        sys.exit(0)
class TenantQueueManager(BaseQueueManager):
    """Consumes tenant created/deleted events from RabbitMQ and forwards them
    to the class-level ``tenant_exchange`` queue for other managers to apply."""

    def __init__(self, settings: Dynaconf):
        super().__init__(settings)
        self.tenant_created_queue_name = self.get_tenant_created_queue_name(settings)
        self.tenant_deleted_queue_name = self.get_tenant_deleted_queue_name(settings)
        self.tenant_events_dlq_name = self.get_tenant_events_dlq_name(settings)

    def initialize_queues(self):
        """Declare the tenant topic exchange, both event queues (dead-lettering
        to the DLQ), the DLQ itself, and the routing-key bindings."""
        self.channel.exchange_declare(exchange=self.tenant_exchange_name, exchange_type="topic")
        for event_queue_name in (self.tenant_created_queue_name, self.tenant_deleted_queue_name):
            self.channel.queue_declare(
                queue=event_queue_name,
                arguments={
                    "x-dead-letter-exchange": "",
                    "x-dead-letter-routing-key": self.tenant_events_dlq_name,
                },
                durable=True,
            )
        self.channel.queue_declare(
            queue=self.tenant_events_dlq_name,
            durable=True,
        )
        # NOTE(review): the created queue binds "tenant.created" but the deleted
        # queue binds "tenant.delete" (not "tenant.deleted") — presumably this
        # matches the publisher's routing keys; verify it is not a typo.
        self.channel.queue_bind(
            exchange=self.tenant_exchange_name, queue=self.tenant_created_queue_name, routing_key="tenant.created"
        )
        self.channel.queue_bind(
            exchange=self.tenant_exchange_name, queue=self.tenant_deleted_queue_name, routing_key="tenant.delete"
        )

    @retry(exceptions=pika.exceptions.AMQPConnectionError, tries=3, delay=5, jitter=(1, 3), logger=logger)
    def start_consuming(self):
        """Consume both tenant event queues until stopped; always cleans up."""
        try:
            self.establish_connection()
            for event_queue_name, callback in (
                (self.tenant_created_queue_name, self.on_tenant_created),
                (self.tenant_deleted_queue_name, self.on_tenant_deleted),
            ):
                self.channel.basic_consume(queue=event_queue_name, on_message_callback=callback)
            self.channel.start_consuming()
        except Exception:
            logger.error("An unexpected error occurred while consuming messages. Consuming will stop.", exc_info=True)
            raise
        finally:
            self.stop_consuming()

    def get_tenant_created_queue_name(self, settings: Dynaconf):
        """Queue name for tenant-created events (pod-specific unless default)."""
        return self.get_queue_name_with_suffix(
            suffix=settings.rabbitmq.tenant_created_event_queue_suffix, pod_name=settings.kubernetes.pod_name
        )

    def get_tenant_deleted_queue_name(self, settings: Dynaconf):
        """Queue name for tenant-deleted events (pod-specific unless default)."""
        return self.get_queue_name_with_suffix(
            suffix=settings.rabbitmq.tenant_deleted_event_queue_suffix, pod_name=settings.kubernetes.pod_name
        )

    def get_tenant_events_dlq_name(self, settings: Dynaconf):
        """Dead-letter queue name for tenant events (pod-specific unless default)."""
        return self.get_queue_name_with_suffix(
            suffix=settings.rabbitmq.tenant_event_dlq_suffix, pod_name=settings.kubernetes.pod_name
        )

    def get_queue_name_with_suffix(self, suffix: str, pod_name: str):
        """Return ``<pod_name><suffix>``, or the default name when forced or no pod name."""
        if self.use_default_queue_name() or not pod_name:
            return self.get_default_queue_name()
        return f"{pod_name}{suffix}"

    def use_default_queue_name(self):
        """Subclasses may override to force the default queue name."""
        return False

    def get_default_queue_name(self):
        """Fallback queue name; must be provided by a subclass when used."""
        raise NotImplementedError("Queue name method not implemented")

    def on_tenant_created(self, ch: Channel, method, properties, body):
        """Ack the delivery and enqueue a ("create", tenant_id) event."""
        logger.info("Received tenant created event")
        self._forward_event("create", ch, method, body)

    def on_tenant_deleted(self, ch: Channel, method, properties, body):
        """Ack the delivery and enqueue a ("delete", tenant_id) event."""
        logger.info("Received tenant deleted event")
        self._forward_event("delete", ch, method, body)

    def _forward_event(self, event: str, ch: Channel, method, body):
        # Ack first, then hand off: the event is applied asynchronously by the
        # consumer of the shared ``tenant_exchange`` queue.
        payload = json.loads(body)
        ch.basic_ack(delivery_tag=method.delivery_tag)
        self.tenant_exchange.put((event, payload["tenantId"]))

    def purge_queues(self):
        """Empty both tenant event queues (best effort)."""
        self.establish_connection()
        try:
            self.channel.queue_purge(self.tenant_created_queue_name)
            self.channel.queue_purge(self.tenant_deleted_queue_name)
            logger.info("Queues purged.")
        except pika.exceptions.ChannelWrongStateError:
            # Channel already closed/closing — nothing to purge.
            pass
class ServiceQueueManager(BaseQueueManager):
    """Per-tenant request/response queue manager.

    Maintains one durable request queue per tenant, bound to the request
    exchange under the tenant id as routing key; polls the queues round-robin,
    hands each message to a ``MessageProcessor``, and publishes the result to
    the response exchange. Tenant add/remove events arrive via the shared
    ``tenant_exchange`` queue (fed by ``TenantQueueManager``).
    """

    def __init__(self, settings: Dynaconf):
        super().__init__(settings)
        self.service_request_exchange_name = settings.rabbitmq.service_request_exchange_name
        self.service_response_exchange_name = settings.rabbitmq.service_response_exchange_name
        self.service_queue_prefix = settings.rabbitmq.service_request_queue_prefix
        self.service_dlq_name = settings.rabbitmq.service_dlq_name
        self.tenant_ids = self.get_initial_tenant_ids(tenant_endpoint_url=settings.storage.tenant_server.endpoint)
        self._consuming = False

    def _tenant_queue_name(self, tenant_id: str) -> str:
        """Request queue name for a tenant."""
        return self.service_queue_prefix + "_" + tenant_id

    def _declare_and_bind_tenant_queue(self, tenant_id: str) -> str:
        """Declare the tenant's request queue and bind it to the request exchange.

        Kept as a single helper so the queue arguments are identical everywhere
        the queue is declared — RabbitMQ rejects a redeclaration with different
        arguments (PRECONDITION_FAILED). The original ``on_tenant_created``
        omitted ``x-max-priority`` and would have conflicted with
        ``initialize_queues`` after a reconnect.
        """
        queue_name = self._tenant_queue_name(tenant_id)
        self.channel.queue_declare(
            queue=queue_name,
            durable=True,
            arguments={
                "x-dead-letter-exchange": "",
                "x-dead-letter-routing-key": self.service_dlq_name,
                "x-expires": self.queue_expiration_time,  # TODO: check if necessary
                "x-max-priority": 2,
            },
        )
        self.channel.queue_bind(
            queue=queue_name, exchange=self.service_request_exchange_name, routing_key=tenant_id
        )
        return queue_name

    def initialize_queues(self):
        """Declare both exchanges and one request queue per known tenant."""
        self.channel.exchange_declare(exchange=self.service_request_exchange_name, exchange_type="direct")
        self.channel.exchange_declare(exchange=self.service_response_exchange_name, exchange_type="direct")
        for tenant_id in self.tenant_ids:
            self._declare_and_bind_tenant_queue(tenant_id)

    @retry(tries=3, delay=5, jitter=(1, 3), logger=logger, exceptions=requests.exceptions.HTTPError)
    def get_initial_tenant_ids(self, tenant_endpoint_url: str) -> list:
        """Fetch the current tenant ids from the tenant service.

        Returns an empty list when the endpoint is unreachable or does not
        return JSON (the original left ``tenants`` unbound in those paths and
        crashed with UnboundLocalError). HTTP error responses are re-raised so
        the ``@retry`` on HTTPError actually fires.
        """
        tenants: list = []
        try:
            response = requests.get(tenant_endpoint_url, timeout=10)
            response.raise_for_status()  # Raise an HTTPError for bad responses
            content_type = response.headers.get("content-type", "")
            # startswith tolerates parameters like "; charset=utf-8".
            if content_type.lower().startswith("application/json"):
                tenants = [tenant["tenantId"] for tenant in response.json()]
            else:
                logger.warning("Response is not in JSON format.")
        except requests.exceptions.HTTPError:
            raise  # let @retry handle transient HTTP failures
        except Exception:
            # Best effort: log (with traceback — the original logger call passed
            # the exception as a stray positional arg) and fall back to [].
            logger.warning("An unexpected error occurred while fetching tenant ids.", exc_info=True)
        return tenants

    @retry(exceptions=pika.exceptions.AMQPConnectionError, tries=3, delay=5, jitter=(1, 3), logger=logger)
    def start_sequential_consume(self, message_processor: Callable):
        """Poll every tenant queue in turn, handling at most one message per
        queue per round, and apply pending tenant events between rounds."""
        self.establish_connection()
        self._consuming = True
        try:
            while self._consuming:
                # Snapshot: check_tenant_exchange may mutate self.tenant_ids.
                for tenant_id in list(self.tenant_ids):
                    if not self._consuming:
                        # A STOP message closed the channel mid-round; a further
                        # basic_get would raise on the closed channel.
                        break
                    queue_name = self._tenant_queue_name(tenant_id)
                    method_frame, properties, body = self.channel.basic_get(queue_name)
                    if method_frame:
                        on_message_callback = self._make_on_message_callback(message_processor, tenant_id)
                        on_message_callback(self.channel, method_frame, properties, body)
                    else:
                        logger.debug("No message returned")
                        time.sleep(self.connection_sleep)
                ### Handle tenant events
                self.check_tenant_exchange()
        except KeyboardInterrupt:
            logger.info("Exiting...")
        finally:
            self.stop_consuming()

    def check_tenant_exchange(self):
        """Drain pending tenant events (non-blocking) and apply each one."""
        while True:
            try:
                event, tenant_id = self.tenant_exchange.get(block=False)
            except queue.Empty:
                # Narrowed from a bare `except Exception`, which also hid
                # failures inside the event handlers.
                logger.debug("No tenant exchange events.")
                break
            try:
                if event == "create":
                    self.on_tenant_created(tenant_id)
                elif event == "delete":
                    self.on_tenant_deleted(tenant_id)
                else:
                    # Previously an unknown event silently stopped the drain.
                    logger.warning(f"Ignoring unknown tenant event {event!r}.")
            except Exception:
                # Best effort: a failed event must not kill the consume loop.
                logger.warning(f"Failed to apply tenant event {event!r} for {tenant_id!r}.", exc_info=True)

    def publish_message_to_input_queue(
        self, tenant_id: str, message: Union[str, bytes, dict], properties: pika.BasicProperties = None
    ):
        """Publish a message to the tenant's request queue.

        str and dict payloads are UTF-8/JSON encoded; bytes pass through as-is.
        """
        if isinstance(message, str):
            message = message.encode("utf-8")
        elif isinstance(message, dict):
            message = json.dumps(message).encode("utf-8")
        self.establish_connection()
        self.channel.basic_publish(
            exchange=self.service_request_exchange_name,
            routing_key=tenant_id,
            properties=properties,
            body=message,
        )
        logger.info(f"Published message to queue {tenant_id}.")

    def purge_queues(self):
        """Empty every tenant request queue (best effort)."""
        self.establish_connection()
        try:
            for tenant_id in self.tenant_ids:
                self.channel.queue_purge(self._tenant_queue_name(tenant_id))
            logger.info("Queues purged.")
        except pika.exceptions.ChannelWrongStateError:
            # Channel already closed/closing — nothing to purge.
            pass

    def on_tenant_created(self, tenant_id: str):
        """Create and bind the request queue for a new tenant and start polling it."""
        self._declare_and_bind_tenant_queue(tenant_id)
        if tenant_id not in self.tenant_ids:  # guard against duplicate create events
            self.tenant_ids.append(tenant_id)
        logger.debug(f"Added tenant {tenant_id}.")

    def on_tenant_deleted(self, tenant_id: str):
        """Unbind and delete the tenant's request queue and stop polling it."""
        queue_name = self._tenant_queue_name(tenant_id)
        self.channel.queue_unbind(queue=queue_name, exchange=self.service_request_exchange_name, routing_key=tenant_id)
        self.channel.queue_delete(queue_name)
        if tenant_id in self.tenant_ids:  # guard against duplicate delete events
            self.tenant_ids.remove(tenant_id)
        logger.debug(f"Deleted tenant {tenant_id}.")

    def _make_on_message_callback(self, message_processor: MessageProcessor, tenant_id: str):
        """Build a pika-style callback that processes one message for a tenant.

        The result dict is published to the response exchange under the tenant
        id, forwarding any "x-" headers. Redelivered messages and processing
        failures are nacked without requeue (dead-lettered). A body equal to
        "STOP" acks and shuts down consuming.
        """

        def process_message_body_and_await_result(unpacked_message_body):
            # Processing the message in a separate thread is necessary for the main thread pika client to be able to
            # process data events (e.g. heartbeats) while the message is being processed.
            with concurrent.futures.ThreadPoolExecutor(max_workers=1) as thread_pool_executor:
                logger.info("Processing payload in separate thread.")
                future = thread_pool_executor.submit(message_processor, unpacked_message_body)
                return future.result()

        def on_message_callback(channel, method, properties, body):
            logger.info(f"Received message from queue with delivery_tag {method.delivery_tag}.")
            if method.redelivered:
                # A redelivery means a previous attempt died mid-processing; dead-letter it.
                logger.warning(f"Declining message with {method.delivery_tag=} due to it being redelivered.")
                channel.basic_nack(method.delivery_tag, requeue=False)
                return
            if body.decode("utf-8") == "STOP":
                logger.info("Received stop signal, stopping consuming...")
                channel.basic_ack(delivery_tag=method.delivery_tag)
                self.stop_consuming()
                return
            try:
                # Only forward custom "x-" headers into the payload and response.
                filtered_message_headers = (
                    {k: v for k, v in properties.headers.items() if k.lower().startswith("x-")}
                    if properties.headers
                    else {}
                )
                logger.debug(f"Processing message with {filtered_message_headers=}.")
                result: dict = (
                    process_message_body_and_await_result({**json.loads(body), **filtered_message_headers}) or {}
                )
                channel.basic_publish(
                    exchange=self.service_response_exchange_name,
                    routing_key=tenant_id,
                    body=json.dumps(result).encode(),
                    properties=pika.BasicProperties(headers=filtered_message_headers),
                )
                logger.info(f"Published result to queue {tenant_id}.")
                # Ack only after the result was published successfully.
                channel.basic_ack(delivery_tag=method.delivery_tag)
                logger.debug(f"Message with {method.delivery_tag=} acknowledged.")
            except FileNotFoundError as e:
                logger.warning(f"{e}, declining message with {method.delivery_tag=}.")
                channel.basic_nack(method.delivery_tag, requeue=False)
            except Exception:
                logger.warning(f"Failed to process message with {method.delivery_tag=}, declining...", exc_info=True)
                channel.basic_nack(method.delivery_tag, requeue=False)
                raise

        return on_message_callback

    def stop_consuming(self):
        """Stop the sequential consume loop, then close channel and connection."""
        self._consuming = False
        # Delegate instead of duplicating the base close logic (was a verbatim copy).
        super().stop_consuming()