feat: wip for multiple tenants

Jonathan Kössler 2024-07-02 18:07:23 +02:00
parent 7624208188
commit 30330937ce
4 changed files with 420 additions and 28 deletions


@@ -4,7 +4,8 @@ from kn_utils.logging import logger
from pyinfra.config.loader import get_pyinfra_validators, validate_settings
from pyinfra.queue.callback import Callback
from pyinfra.queue.manager import QueueManager
# from pyinfra.queue.manager import QueueManager
from pyinfra.queue.sequential_tenants import QueueManager
from pyinfra.utils.opentelemetry import instrument_pika, setup_trace, instrument_app
from pyinfra.webserver.prometheus import (
add_prometheus_endpoint,
@@ -52,4 +53,5 @@ def start_standard_queue_consumer(
webserver_thread = create_webserver_thread_from_settings(app, settings)
webserver_thread.start()
queue_manager.start_consuming(callback)
# queue_manager.start_consuming(callback)
queue_manager.start_sequential_consume(callback)
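For orientation, a minimal sketch of how a service would drive the new entry point. The settings loading mirrors the test script at the end of this commit; echo_processor is a hypothetical callback matching the MessageProcessor type (Callable[[dict], dict]) declared in the new module:

from pyinfra.config.loader import load_settings, local_pyinfra_root_path
from pyinfra.queue.sequential_tenants import QueueManager

def echo_processor(message: dict) -> dict:
    # Receives the unpacked message body (merged with any x-* headers) and
    # returns the result dict that is published to the response exchange.
    return {"echo": message}

settings = load_settings(local_pyinfra_root_path / "config/")
queue_manager = QueueManager(settings)
queue_manager.start_sequential_consume(echo_processor)  # blocks until stopped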


@@ -1,19 +1,16 @@
import atexit
import asyncio
import concurrent.futures
import pika
import os
import json
import logging
import signal
import sys
import requests
import time
import pika.exceptions
from threading import Thread
from dynaconf import Dynaconf
from typing import Callable, Union
from kn_utils.logging import logger
from pika.adapters.blocking_connection import BlockingChannel, BlockingConnection
from pika.adapters.select_connection import SelectConnection
from pika.channel import Channel
from retry import retry
@@ -21,7 +18,6 @@ from retry import retry
from pyinfra.config.loader import validate_settings
from pyinfra.config.validators import queue_manager_validators
pika_logger = logging.getLogger("pika")
pika_logger.setLevel(logging.WARNING) # disables non-informative pika log clutter
@@ -72,19 +68,37 @@ class BaseQueueManager:
# logger.info("Connection to RabbitMQ established, channel open.")
@retry(tries=3, delay=5, jitter=(1, 3), logger=logger)
def establish_connection(self):
if self.connection and self.connection.is_open:
logger.debug("Connection to RabbitMQ already established.")
return
logger.info("Establishing connection to RabbitMQ...")
self.connection = SelectConnection(parameters=self.connection_parameters,
return SelectConnection(parameters=self.connection_parameters,
on_open_callback=self.on_connection_open,
on_open_error_callback=self.on_connection_open_error,
on_close_callback=self.on_connection_close)
def close_connection(self):
# self._consuming = False
if self.connection.is_closing or self.connection.is_closed:
logger.info('Connection is closing or already closed')
else:
logger.info('Closing connection')
self.connection.close()
def on_connection_open(self, unused_connection):
logger.debug("Connection opened")
self.open_channel()
def open_channel(self):
"""Open a new channel with RabbitMQ by issuing the Channel.Open RPC
command. When RabbitMQ responds that the channel is open, the
on_channel_open callback will be invoked by pika.
"""
logger.debug('Creating a new channel')
self.connection.channel(on_open_callback=self.on_channel_open)
def on_connection_open_error(self, unused_connection, err):
@@ -98,9 +112,33 @@ class BaseQueueManager:
def on_channel_open(self, channel):
logger.debug("Channel opened")
self.channel = channel
# self.add_on_channel_close_callback()
self.channel.basic_qos(prefetch_count=1)
self.initialize_queues()
# def add_on_channel_close_callback(self):
# """This method tells pika to call the on_channel_closed method if
# RabbitMQ unexpectedly closes the channel.
# """
# logger.debug('Adding channel close callback')
# self.channel.add_on_close_callback(self.on_channel_closed)
# def on_channel_closed(self, channel, reason):
# """Invoked by pika when RabbitMQ unexpectedly closes the channel.
# Channels are usually closed if you attempt to do something that
# violates the protocol, such as re-declare an exchange or queue with
# different parameters. In this case, we'll close the connection
# to shutdown the object.
# :param pika.channel.Channel: The closed channel
# :param Exception reason: why the channel was closed
# """
# logger.warning('Channel %i was closed: %s', channel, reason)
# self.close_connection()
def is_ready(self):
self.establish_connection()
return self.channel.is_open
@@ -134,7 +172,7 @@ class TenantQueueManager(BaseQueueManager):
self.tenant_events_dlq_name = self.get_tenant_events_dlq_name(settings)
self.event_handlers = {"tenant_created": [], "tenant_deleted": []}
TenantQueueManager.tenant_ids = self.get_initial_tenant_ids(
self.get_initial_tenant_ids(
tenant_endpoint_url=settings.storage.tenant_server.endpoint
)
@@ -146,7 +184,6 @@ class TenantQueueManager(BaseQueueManager):
arguments={
"x-dead-letter-exchange": "",
"x-dead-letter-routing-key": self.tenant_events_dlq_name,
"x-expires": self.queue_expiration_time,
},
durable=True,
)
@@ -155,13 +192,11 @@ class TenantQueueManager(BaseQueueManager):
arguments={
"x-dead-letter-exchange": "",
"x-dead-letter-routing-key": self.tenant_events_dlq_name,
"x-expires": self.queue_expiration_time,
},
durable=True,
)
self.channel.queue_declare(
queue=self.tenant_events_dlq_name,
arguments={"x-expires": self.queue_expiration_time},
durable=True,
)
@@ -175,6 +210,11 @@ class TenantQueueManager(BaseQueueManager):
self.channel.basic_consume(queue=self.tenant_created_queue_name, on_message_callback=self.on_tenant_created)
self.channel.basic_consume(queue=self.tenant_deleted_queue_name, on_message_callback=self.on_tenant_deleted)
def start(self):
self.connection = self.establish_connection()
if self.connection is not None:
self.connection.ioloop.start()
@retry(tries=3, delay=5, jitter=(1, 3), logger=logger, exceptions=requests.exceptions.HTTPError)
def get_initial_tenant_ids(self, tenant_endpoint_url: str) -> list:
tenants = []
try:
@@ -182,13 +222,13 @@ class TenantQueueManager(BaseQueueManager):
response.raise_for_status() # Raise an HTTPError for bad responses
if response.headers["content-type"].lower() == "application/json":
tenant_ids = [tenant["tenantId"] for tenant in response.json()]
tenants = [tenant["tenantId"] for tenant in response.json()]
else:
logger.warning("Response is not in JSON format.")
except Exception as e:
logger.warning(f"An unexpected error occurred: {e}")
return tenant_ids
self.tenant_ids.extend(tenants)
def get_tenant_created_queue_name(self, settings: Dynaconf):
return self.get_queue_name_with_suffix(
@@ -224,10 +264,10 @@ class TenantQueueManager(BaseQueueManager):
# TODO: test callback
tenant_id = body["tenantId"]
TenantQueueManager.tenant_ids.append(tenant_id)
self.tenant_ids.append(tenant_id)
self._trigger_event("tenant_created", tenant_id)
def on_tenant_deleted(self, ch, method, properties, body):
def on_tenant_deleted(self, ch: Channel, method, properties, body):
logger.info("Received tenant deleted event")
message = json.loads(body)
logger.info(f"Tenant Deleted: {message}")
@@ -235,7 +275,7 @@ class TenantQueueManager(BaseQueueManager):
# TODO: test callback
tenant_id = body["tenantId"]
TenantQueueManager.tenant_ids.remove(tenant_id)
self.tenant_ids.remove(tenant_id)
self._trigger_event("tenant_deleted", tenant_id)
def _trigger_event(self, event_type, tenant_id):
@@ -279,7 +319,7 @@ class ServiceQueueManager(BaseQueueManager):
self.channel.exchange_declare(exchange=self.service_request_exchange_name, exchange_type="direct")
self.channel.exchange_declare(exchange=self.service_response_exchange_name, exchange_type="direct")
for tenant_id in ServiceQueueManager.tenant_ids:
for tenant_id in self.tenant_ids:
queue_name = self.service_queue_prefix + "_" + tenant_id
self.channel.queue_declare(
queue=queue_name,
@@ -293,18 +333,22 @@ class ServiceQueueManager(BaseQueueManager):
)
self.channel.queue_bind(queue_name, self.service_request_exchange_name)
def start_consuming(self):
for tenant_id in ServiceQueueManager.tenant_ids:
def start_consuming(self, message_processor: Callable):
self.connection = self.establish_connection()
for tenant_id in self.tenant_ids:
queue_name = self.service_queue_prefix + "_" + tenant_id
message_callback = self._make_on_message_callback(message_processor=MessageProcessor, tenant_id=tenant_id)
message_callback = self._make_on_message_callback(message_processor=message_processor, tenant_id=tenant_id)
self.channel.basic_consume(
queue=queue_name,
on_message_callback=message_callback,
)
logger.info(f"Starting to consume messages for queue {queue_name}...")
self.channel.start_consuming()
self.connection.ioloop.start()
# self.channel.start_consuming()
if self.connection is not None:
self.connection.ioloop.start()
else:
logger.info("Connection is None, cannot start ioloop")
def publish_message_to_input_queue(
self, tenant_id: str, message: Union[str, bytes, dict], properties: pika.BasicProperties = None
@@ -326,7 +370,7 @@ class ServiceQueueManager(BaseQueueManager):
def purge_queues(self):
self.establish_connection()
try:
for tenant_id in ServiceQueueManager.tenant_ids:
for tenant_id in self.tenant_ids:
queue_name = self.service_queue_prefix + "_" + tenant_id
self.channel.queue_purge(queue_name)
logger.info("Queues purged.")
@@ -371,7 +415,10 @@ class ServiceQueueManager(BaseQueueManager):
# logger.debug("Waiting for payload processing to finish...")
# self.connection.sleep(self.connection_sleep)
return future.result()
loop = asyncio.get_event_loop()
return loop.run_until_complete(loop.run_in_executor(None, future.result))
# return future.result()
def on_message_callback(channel, method, properties, body):
logger.info(f"Received message from queue with delivery_tag {method.delivery_tag}.")
@@ -416,4 +463,4 @@ class ServiceQueueManager(BaseQueueManager):
channel.basic_nack(method.delivery_tag, requeue=False)
raise
return on_message_callback
return on_message_callback
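The refactor above swaps pika's BlockingConnection for the callback-driven SelectConnection: establish_connection now only builds the connection object, and every later step (channel, queue setup, consumers) is triggered from a callback once ioloop.start() is running. A self-contained sketch of that pattern follows; the queue name, localhost broker, and handlers are illustrative, not part of this codebase:

import pika

def on_message(channel, method, properties, body):
    print(f"Received {body!r}")
    channel.basic_ack(method.delivery_tag)

def on_channel_open(channel):
    channel.basic_qos(prefetch_count=1)
    # Declare first; only start consuming once the broker confirms it.
    channel.queue_declare(
        queue="demo",
        callback=lambda frame: channel.basic_consume(
            queue="demo", on_message_callback=on_message
        ),
    )

def on_connection_open(connection):
    connection.channel(on_open_callback=on_channel_open)

connection = pika.SelectConnection(
    parameters=pika.ConnectionParameters(host="localhost"),
    on_open_callback=on_connection_open,
    on_open_error_callback=lambda conn, err: conn.ioloop.stop(),
)
try:
    connection.ioloop.start()  # blocks; all I/O is dispatched from this loop
except KeyboardInterrupt:
    connection.close()
    connection.ioloop.start()  # lets pika finish the clean shutdown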


@@ -0,0 +1,342 @@
import atexit
import concurrent.futures
import json
import logging
import requests
import signal
import sys
import time
from typing import Callable, Union
import pika
import pika.exceptions
from dynaconf import Dynaconf
from kn_utils.logging import logger
from pika.adapters.blocking_connection import BlockingChannel, BlockingConnection
from retry import retry
from pyinfra.config.loader import validate_settings
from pyinfra.config.validators import queue_manager_validators
logger.set_level("DEBUG")
pika_logger = logging.getLogger("pika")
pika_logger.setLevel(logging.WARNING) # disables non-informative pika log clutter
MessageProcessor = Callable[[dict], dict]
class QueueManager:
def __init__(self, settings: Dynaconf):
validate_settings(settings, queue_manager_validators)
self.tenant_created_queue_name = self.get_tenant_created_queue_name(settings)
self.tenant_deleted_queue_name = self.get_tenant_deleted_queue_name(settings)
self.tenant_events_dlq_name = self.get_tenant_events_dlq_name(settings)
self.connection_sleep = settings.rabbitmq.connection_sleep
self.queue_expiration_time = settings.rabbitmq.queue_expiration_time
self.tenant_exchange_name = settings.rabbitmq.tenant_exchange_name
self.service_request_exchange_name = settings.rabbitmq.service_request_exchange_name
self.service_response_exchange_name = settings.rabbitmq.service_response_exchange_name
self.service_queue_prefix = settings.rabbitmq.service_request_queue_prefix
self.service_dlq_name = settings.rabbitmq.service_dlq_name
self.connection_parameters = self.create_connection_parameters(settings)
self.connection: Union[BlockingConnection, None] = None
self.channel: Union[BlockingChannel, None] = None
self.tenant_ids = self.get_initial_tenant_ids(tenant_endpoint_url=settings.storage.tenant_server.endpoint)
self._consuming = False
atexit.register(self.stop_consuming)
signal.signal(signal.SIGTERM, self._handle_stop_signal)
signal.signal(signal.SIGINT, self._handle_stop_signal)
@staticmethod
def create_connection_parameters(settings: Dynaconf):
credentials = pika.PlainCredentials(username=settings.rabbitmq.username, password=settings.rabbitmq.password)
pika_connection_params = {
"host": settings.rabbitmq.host,
"port": settings.rabbitmq.port,
"credentials": credentials,
"heartbeat": settings.rabbitmq.heartbeat,
}
return pika.ConnectionParameters(**pika_connection_params)
@retry(tries=3, delay=5, jitter=(1, 3), logger=logger, exceptions=requests.exceptions.HTTPError)
def get_initial_tenant_ids(self, tenant_endpoint_url: str) -> list:
tenants = []
try:
response = requests.get(tenant_endpoint_url, timeout=10)
response.raise_for_status() # Raise an HTTPError for bad responses
if response.headers["content-type"].lower() == "application/json":
tenants = [tenant["tenantId"] for tenant in response.json()]
else:
logger.warning("Response is not in JSON format.")
except Exception as e:
logger.warning(f"An unexpected error occurred: {e}")
return tenants
def get_tenant_created_queue_name(self, settings: Dynaconf):
return self.get_queue_name_with_suffix(
suffix=settings.rabbitmq.tenant_created_event_queue_suffix, pod_name=settings.kubernetes.pod_name
)
def get_tenant_deleted_queue_name(self, settings: Dynaconf):
return self.get_queue_name_with_suffix(
suffix=settings.rabbitmq.tenant_deleted_event_queue_suffix, pod_name=settings.kubernetes.pod_name
)
def get_tenant_events_dlq_name(self, settings: Dynaconf):
return self.get_queue_name_with_suffix(
suffix=settings.rabbitmq.tenant_event_dlq_suffix, pod_name=settings.kubernetes.pod_name
)
def get_queue_name_with_suffix(self, suffix: str, pod_name: str):
if not self.use_default_queue_name() and pod_name:
return f"{pod_name}{suffix}"
return self.get_default_queue_name()
def use_default_queue_name(self):
return False
def get_default_queue_name(self):
raise NotImplementedError("Queue name method not implemented")
@retry(tries=3, delay=5, jitter=(1, 3), logger=logger)
def establish_connection(self):
# TODO: set sensible retry parameters
if self.connection and self.connection.is_open:
logger.debug("Connection to RabbitMQ already established.")
return
logger.info("Establishing connection to RabbitMQ...")
self.connection = pika.BlockingConnection(parameters=self.connection_parameters)
logger.debug("Opening channel...")
self.channel = self.connection.channel()
self.channel.basic_qos(prefetch_count=1)
args = {
"x-dead-letter-exchange": "",
"x-dead-letter-routing-key": self.tenant_events_dlq_name,
}
### Declare exchanges for tenants and responses
self.channel.exchange_declare(exchange=self.tenant_exchange_name, exchange_type="topic")
self.channel.exchange_declare(exchange=self.service_request_exchange_name, exchange_type="direct")
self.channel.exchange_declare(exchange=self.service_response_exchange_name, exchange_type="direct")
self.channel.queue_declare(self.tenant_created_queue_name, arguments=args, auto_delete=False, durable=True)
self.channel.queue_declare(self.tenant_deleted_queue_name, arguments=args, auto_delete=False, durable=True)
self.channel.queue_bind(
exchange=self.tenant_exchange_name, queue=self.tenant_created_queue_name, routing_key="tenant.created"
)
self.channel.queue_bind(
exchange=self.tenant_exchange_name, queue=self.tenant_deleted_queue_name, routing_key="tenant.delete"
)
for tenant_id in self.tenant_ids:
queue_name = self.service_queue_prefix + "_" + tenant_id
self.channel.queue_declare(
queue=queue_name,
durable=True,
arguments={
"x-dead-letter-exchange": "",
"x-dead-letter-routing-key": self.service_dlq_name,
"x-expires": self.queue_expiration_time, # TODO: check if necessary
"x-max-priority": 2,
},
)
self.channel.queue_bind(
queue=queue_name, exchange=self.service_request_exchange_name, routing_key=tenant_id
)
logger.info("Connection to RabbitMQ established, channel open.")
def is_ready(self):
self.establish_connection()
return self.channel.is_open
@retry(exceptions=pika.exceptions.AMQPConnectionError, tries=3, delay=5, jitter=(1, 3), logger=logger)
def start_sequential_consume(self, message_processor: Callable):
self.establish_connection()
self._consuming = True
try:
while self._consuming:
for tenant_id in self.tenant_ids:
queue_name = self.service_queue_prefix + "_" + tenant_id
method_frame, properties, body = self.channel.basic_get(queue_name)
if method_frame:
on_message_callback = self._make_on_message_callback(message_processor, tenant_id)
on_message_callback(self.channel, method_frame, properties, body)
else:
logger.debug("No message returned")
time.sleep(self.connection_sleep)
### Handle tenant events
self.check_tenant_created_queue()
self.check_tenant_deleted_queue()
except KeyboardInterrupt:
logger.info("Exiting...")
finally:
self.stop_consuming()
def check_tenant_created_queue(self):
while True:
method_frame, properties, body = self.channel.basic_get(self.tenant_created_queue_name)
if method_frame:
self.channel.basic_ack(delivery_tag=method_frame.delivery_tag)
message = json.loads(body)
tenant_id = message["tenantId"]
self.on_tenant_created(tenant_id)
else:
logger.debug("No more tenant created events.")
break
def check_tenant_deleted_queue(self):
while True:
method_frame, properties, body = self.channel.basic_get(self.tenant_deleted_queue_name)
if method_frame:
self.channel.basic_ack(delivery_tag=method_frame.delivery_tag)
message = json.loads(body)
tenant_id = message["tenantId"]
self.on_tenant_deleted(tenant_id)
else:
logger.debug("No more tenant deleted events.")
break
def on_tenant_created(self, tenant_id: str):
queue_name = self.service_queue_prefix + "_" + tenant_id
self.channel.queue_declare(
queue=queue_name,
durable=True,
arguments={
"x-dead-letter-exchange": "",
"x-dead-letter-routing-key": self.service_dlq_name,
"x-expires": self.queue_expiration_time, # TODO: check if necessary
},
)
self.channel.queue_bind(queue=queue_name, exchange=self.service_request_exchange_name, routing_key=tenant_id)
self.tenant_ids.append(tenant_id)
def on_tenant_deleted(self, tenant_id: str):
queue_name = self.service_queue_prefix + "_" + tenant_id
self.channel.queue_unbind(queue=queue_name, exchange=self.service_request_exchange_name, routing_key=tenant_id)
self.channel.queue_delete(queue_name)
self.tenant_ids.remove(tenant_id)
def stop_consuming(self):
self._consuming = False
if self.channel and self.channel.is_open:
logger.info("Stopping consuming...")
self.channel.stop_consuming()
logger.info("Closing channel...")
self.channel.close()
if self.connection and self.connection.is_open:
logger.info("Closing connection to RabbitMQ...")
self.connection.close()
def publish_message_to_input_queue(
self, tenant_id: str, message: Union[str, bytes, dict], properties: pika.BasicProperties = None
):
if isinstance(message, str):
message = message.encode("utf-8")
elif isinstance(message, dict):
message = json.dumps(message).encode("utf-8")
self.establish_connection()
self.channel.basic_publish(
exchange=self.service_request_exchange_name,
routing_key=tenant_id,
properties=properties,
body=message,
)
logger.info(f"Published message to queue {tenant_id}.")
def purge_queues(self):
self.establish_connection()
try:
self.channel.queue_purge(self.tenant_created_queue_name)
self.channel.queue_purge(self.tenant_deleted_queue_name)
for tenant_id in self.tenant_ids:
queue_name = self.service_queue_prefix + "_" + tenant_id
self.channel.queue_purge(queue_name)
logger.info("Queues purged.")
except pika.exceptions.ChannelWrongStateError:
pass
def get_message_from_output_queue(self, queue: str):
self.establish_connection()
return self.channel.basic_get(queue, auto_ack=True)
def _make_on_message_callback(self, message_processor: MessageProcessor, tenant_id: str):
def process_message_body_and_await_result(unpacked_message_body):
# Processing the message in a separate thread is necessary for the main thread pika client to be able to
# process data events (e.g. heartbeats) while the message is being processed.
with concurrent.futures.ThreadPoolExecutor(max_workers=1) as thread_pool_executor:
logger.info("Processing payload in separate thread.")
future = thread_pool_executor.submit(message_processor, unpacked_message_body)
return future.result()
def on_message_callback(channel, method, properties, body):
logger.info(f"Received message from queue with delivery_tag {method.delivery_tag}.")
if method.redelivered:
logger.warning(f"Declining message with {method.delivery_tag=} due to it being redelivered.")
channel.basic_nack(method.delivery_tag, requeue=False)
return
if body.decode("utf-8") == "STOP":
logger.info(f"Received stop signal, stopping consuming...")
channel.basic_ack(delivery_tag=method.delivery_tag)
self.stop_consuming()
return
try:
filtered_message_headers = (
{k: v for k, v in properties.headers.items() if k.lower().startswith("x-")}
if properties.headers
else {}
)
logger.debug(f"Processing message with {filtered_message_headers=}.")
result: dict = (
process_message_body_and_await_result({**json.loads(body), **filtered_message_headers}) or {}
)
channel.basic_publish(
exchange=self.service_response_exchange_name,
routing_key=tenant_id,
body=json.dumps(result).encode(),
properties=pika.BasicProperties(headers=filtered_message_headers),
)
logger.info(f"Published result to queue {tenant_id}.")
channel.basic_ack(delivery_tag=method.delivery_tag)
logger.debug(f"Message with {method.delivery_tag=} acknowledged.")
except FileNotFoundError as e:
logger.warning(f"{e}, declining message with {method.delivery_tag=}.")
channel.basic_nack(method.delivery_tag, requeue=False)
except Exception:
logger.warning(f"Failed to process message with {method.delivery_tag=}, declining...", exc_info=True)
channel.basic_nack(method.delivery_tag, requeue=False)
raise
return on_message_callback
def _handle_stop_signal(self, signum, *args, **kwargs):
logger.info(f"Received signal {signum}, stopping consuming...")
self.stop_consuming()
sys.exit(0)
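A design note on the module above: it polls with basic_get instead of registering basic_consume callbacks, so each pass of the while-loop takes at most one message per tenant. A busy tenant therefore cannot starve the others, at the cost of up to connection_sleep of added latency per idle pass. A hypothetical smoke test of a single pass, with the tenant id mirroring the test script below:

qm = QueueManager(settings)
qm.establish_connection()
qm.publish_message_to_input_queue(tenant_id="redaction", message={"text": "hi"})

# start_sequential_consume() loops forever; one pass reduces to:
queue_name = qm.service_queue_prefix + "_redaction"
method, properties, body = qm.channel.basic_get(queue_name)
if method:
    callback = qm._make_on_message_callback(lambda m: {"ok": True}, "redaction")
    callback(qm.channel, method, properties, body)  # process, publish, ack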


@@ -5,7 +5,8 @@ from operator import itemgetter
from kn_utils.logging import logger
from pyinfra.config.loader import load_settings, local_pyinfra_root_path
from pyinfra.queue.manager import QueueManager
# from pyinfra.queue.manager import QueueManager
from pyinfra.queue.sequential_tenants import QueueManager
from pyinfra.storage.storages.s3 import get_s3_storage_from_settings
settings = load_settings(local_pyinfra_root_path / "config/")
@@ -41,7 +42,7 @@ def main():
message = upload_json_and_make_message_body()
queue_manager.publish_message_to_input_queue(message)
queue_manager.publish_message_to_input_queue(tenant_id="redaction", message=message)
logger.info(f"Put {message} on {settings.rabbitmq.input_queue}.")
storage = get_s3_storage_from_settings(settings)
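The script publishes a request but never reads a reply. If a response queue were bound to the response exchange for this tenant, polling for the result could look like the sketch below; the queue name is hypothetical, since this commit declares the response exchange but no response queues:

import json
import time

for _ in range(30):
    method, properties, body = queue_manager.get_message_from_output_queue("redaction_responses")
    if method:
        logger.info(f"Received result: {json.loads(body)}")
        break
    time.sleep(1.0)
else:
    logger.warning("No response within 30 seconds.")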