feat: wip for multiple tenants

This commit is contained in:
Jonathan Kössler 2024-07-03 17:51:47 +02:00
parent 30330937ce
commit c81d967aee
3 changed files with 129 additions and 178 deletions

View File

@ -1,11 +1,13 @@
from dynaconf import Dynaconf
from fastapi import FastAPI
from kn_utils.logging import logger
import multiprocessing
from threading import Thread
from pyinfra.config.loader import get_pyinfra_validators, validate_settings
from pyinfra.queue.callback import Callback
# from pyinfra.queue.manager import QueueManager
from pyinfra.queue.sequential_tenants import QueueManager
from pyinfra.queue.threaded_tenants import ServiceQueueManager, TenantQueueManager
from pyinfra.utils.opentelemetry import instrument_pika, setup_trace, instrument_app
from pyinfra.webserver.prometheus import (
add_prometheus_endpoint,
@ -35,7 +37,8 @@ def start_standard_queue_consumer(
app = app or FastAPI()
queue_manager = QueueManager(settings)
tenant_manager = TenantQueueManager(settings)
service_manager = ServiceQueueManager(settings)
if settings.metrics.prometheus.enabled:
logger.info("Prometheus metrics enabled.")
@ -48,10 +51,18 @@ def start_standard_queue_consumer(
instrument_pika()
instrument_app(app)
app = add_health_check_endpoint(app, queue_manager.is_ready)
# app = add_health_check_endpoint(app, queue_manager.is_ready)
app = add_health_check_endpoint(app, service_manager.is_ready)
webserver_thread = create_webserver_thread_from_settings(app, settings)
webserver_thread.start()
# queue_manager.start_consuming(callback)
queue_manager.start_sequential_consume(callback)
# queue_manager.start_sequential_consume(callback)
# p1 = multiprocessing.Process(target=tenant_manager.start_consuming, daemon=True)
# p2 = multiprocessing.Process(target=service_manager.start_sequential_consume, kwargs={"callback":callback}, daemon=True)
thread = Thread(target=tenant_manager.start_consuming, daemon=True)
thread.start()
# p1.start()
# p2.start()
service_manager.start_sequential_consume(callback)

View File

@ -117,6 +117,7 @@ class QueueManager:
return
logger.info("Establishing connection to RabbitMQ...")
logger.info(self.__class__.__name__)
self.connection = pika.BlockingConnection(parameters=self.connection_parameters)
logger.debug("Opening channel...")

View File

@ -1,17 +1,18 @@
import atexit
import asyncio
import concurrent.futures
import pika
import queue
import json
import logging
import signal
import sys
import requests
import time
import pika.exceptions
from dynaconf import Dynaconf
from typing import Callable, Union
from pika.adapters.blocking_connection import BlockingChannel, BlockingConnection
from kn_utils.logging import logger
from pika.adapters.select_connection import SelectConnection
from pika.channel import Channel
from retry import retry
@ -25,14 +26,14 @@ MessageProcessor = Callable[[dict], dict]
class BaseQueueManager:
tenant_ids = []
tenant_exchange = queue.Queue()
def __init__(self, settings: Dynaconf):
validate_settings(settings, queue_manager_validators)
self.connection_parameters = self.create_connection_parameters(settings)
self.connection = None
self.channel = None
self.connection: Union[BlockingConnection, None] = None
self.channel: Union[BlockingChannel, None] = None
self.connection_sleep = settings.rabbitmq.connection_sleep
self.queue_expiration_time = settings.rabbitmq.queue_expiration_time
self.tenant_exchange_name = settings.rabbitmq.tenant_exchange_name
@ -52,22 +53,6 @@ class BaseQueueManager:
}
return pika.ConnectionParameters(**pika_connection_params)
# @retry(tries=3, delay=5, jitter=(1, 3), logger=logger)
# def establish_connection(self):
# if self.connection and self.connection.is_open:
# logger.debug("Connection to RabbitMQ already established.")
# return
# logger.info("Establishing connection to RabbitMQ...")
# self.connection = pika.BlockingConnection(parameters=self.connection_parameters)
# logger.debug("Opening channel...")
# self.channel = self.connection.channel()
# self.channel.basic_qos(prefetch_count=1)
# self.initialize_queues()
# logger.info("Connection to RabbitMQ established, channel open.")
@retry(tries=3, delay=5, jitter=(1, 3), logger=logger)
def establish_connection(self):
if self.connection and self.connection.is_open:
@ -75,69 +60,15 @@ class BaseQueueManager:
return
logger.info("Establishing connection to RabbitMQ...")
return SelectConnection(parameters=self.connection_parameters,
on_open_callback=self.on_connection_open,
on_open_error_callback=self.on_connection_open_error,
on_close_callback=self.on_connection_close)
def close_connection(self):
    """Close the RabbitMQ connection unless it is already closing or closed."""
    # self._consuming = False
    if self.connection.is_closing or self.connection.is_closed:
        logger.info('Connection is closing or already closed')
    else:
        logger.info('Closing connection')
        # NOTE(review): assumes self.connection is set; will raise AttributeError
        # if called before a connection was ever established — TODO confirm callers.
        self.connection.close()
def on_connection_open(self, unused_connection):
logger.debug("Connection opened")
self.open_channel()
logger.info(self.__class__.__name__)
self.connection = pika.BlockingConnection(parameters=self.connection_parameters)
def open_channel(self):
    """Open a new channel with RabbitMQ by issuing the Channel.Open RPC
    command. When RabbitMQ responds that the channel is open, the
    on_channel_open callback will be invoked by pika.
    """
    logger.debug('Creating a new channel')
    # Asynchronous open: the channel object is delivered to on_channel_open,
    # not returned here (SelectConnection-style API).
    self.connection.channel(on_open_callback=self.on_channel_open)
def on_connection_open_error(self, unused_connection, err):
    """SelectConnection callback: the connection attempt failed; schedule a retry."""
    logger.error(f"Connection open failed, reopening in {self.connection_sleep} seconds: {err}")
    # Re-run establish_connection on the ioloop after the configured backoff.
    self.connection.ioloop.call_later(self.connection_sleep, self.establish_connection)
def on_connection_close(self, unused_connection, reason):
    """SelectConnection callback: the connection dropped; schedule a reconnect."""
    logger.warning(f"Connection closed, reopening in {self.connection_sleep} seconds: {reason}")
    # Same backoff-and-retry path as a failed open.
    self.connection.ioloop.call_later(self.connection_sleep, self.establish_connection)
def on_channel_open(self, channel):
logger.debug("Channel opened")
self.channel = channel
# self.add_on_channel_close_callback()
logger.debug("Opening channel...")
self.channel = self.connection.channel()
self.channel.basic_qos(prefetch_count=1)
self.initialize_queues()
# def add_on_channel_close_callback(self):
# """This method tells pika to call the on_channel_closed method if
# RabbitMQ unexpectedly closes the channel.
# """
# logger.debug('Adding channel close callback')
# self.channel.add_on_close_callback(self.on_channel_closed)
# def on_channel_closed(self, channel, reason):
# """Invoked by pika when RabbitMQ unexpectedly closes the channel.
# Channels are usually closed if you attempt to do something that
# violates the protocol, such as re-declare an exchange or queue with
# different parameters. In this case, we'll close the connection
# to shutdown the object.
# :param pika.channel.Channel: The closed channel
# :param Exception reason: why the channel was closed
# """
# logger.warning('Channel %i was closed: %s', channel, reason)
# self.close_connection()
logger.info("Connection to RabbitMQ established, channel open.")
def is_ready(self):
self.establish_connection()
@ -170,11 +101,6 @@ class TenantQueueManager(BaseQueueManager):
self.tenant_created_queue_name = self.get_tenant_created_queue_name(settings)
self.tenant_deleted_queue_name = self.get_tenant_deleted_queue_name(settings)
self.tenant_events_dlq_name = self.get_tenant_events_dlq_name(settings)
self.event_handlers = {"tenant_created": [], "tenant_deleted": []}
self.get_initial_tenant_ids(
tenant_endpoint_url=settings.storage.tenant_server.endpoint
)
def initialize_queues(self):
self.channel.exchange_declare(exchange=self.tenant_exchange_name, exchange_type="topic")
@ -206,29 +132,20 @@ class TenantQueueManager(BaseQueueManager):
self.channel.queue_bind(
exchange=self.tenant_exchange_name, queue=self.tenant_deleted_queue_name, routing_key="tenant.delete"
)
@retry(exceptions=pika.exceptions.AMQPConnectionError, tries=3, delay=5, jitter=(1, 3), logger=logger)
def start_consuming(self):
    """Register consumers for the tenant-created and tenant-deleted queues.

    NOTE(review): this only registers the callbacks; no
    `self.channel.start_consuming()` is visible here, so the consume loop is
    presumably driven elsewhere (e.g. by the thread started in the webserver
    setup) — TODO confirm.
    """
    self.channel.basic_consume(queue=self.tenant_created_queue_name, on_message_callback=self.on_tenant_created)
    self.channel.basic_consume(queue=self.tenant_deleted_queue_name, on_message_callback=self.on_tenant_deleted)
def start(self):
    """Establish a connection and run its ioloop (SelectConnection-style entry point)."""
    # establish_connection is expected to return a connection object here;
    # a None return means the connection was skipped/failed and we do not block.
    self.connection = self.establish_connection()
    if self.connection is not None:
        # Blocks the calling thread, driving all pika callbacks.
        self.connection.ioloop.start()
@retry(tries=3, delay=5, jitter=(1, 3), logger=logger, exceptions=requests.exceptions.HTTPError)
def get_initial_tenant_ids(self, tenant_endpoint_url: str) -> list:
try:
response = requests.get(tenant_endpoint_url, timeout=10)
response.raise_for_status() # Raise an HTTPError for bad responses
if response.headers["content-type"].lower() == "application/json":
tenants = [tenant["tenantId"] for tenant in response.json()]
else:
logger.warning("Response is not in JSON format.")
except Exception as e:
logger.warning("An unexpected error occurred:", e)
self.tenant_ids.extend(tenants)
self.establish_connection()
self.channel.basic_consume(queue=self.tenant_created_queue_name, on_message_callback=self.on_tenant_created)
self.channel.basic_consume(queue=self.tenant_deleted_queue_name, on_message_callback=self.on_tenant_deleted)
self.channel.start_consuming()
except Exception:
logger.error("An unexpected error occurred while consuming messages. Consuming will stop.", exc_info=True)
raise
finally:
self.stop_consuming()
def get_tenant_created_queue_name(self, settings: Dynaconf):
return self.get_queue_name_with_suffix(
@ -259,52 +176,31 @@ class TenantQueueManager(BaseQueueManager):
def on_tenant_created(self, ch: Channel, method, properties, body):
logger.info("Received tenant created event")
message = json.loads(body)
logger.info(f"Tenant Created: {message}")
ch.basic_ack(delivery_tag=method.delivery_tag)
# TODO: test callback
tenant_id = body["tenantId"]
self.tenant_ids.append(tenant_id)
self._trigger_event("tenant_created", tenant_id)
tenant_id = message["tenantId"]
self.tenant_exchange.put(("create", tenant_id))
def on_tenant_deleted(self, ch: Channel, method, properties, body):
logger.info("Received tenant deleted event")
message = json.loads(body)
logger.info(f"Tenant Deleted: {message}")
ch.basic_ack(delivery_tag=method.delivery_tag)
# TODO: test callback
tenant_id = body["tenantId"]
self.tenant_ids.remove(tenant_id)
self._trigger_event("tenant_deleted", tenant_id)
def _trigger_event(self, event_type, tenant_id):
handler = self.event_handlers.get(event_type)
if handler:
try:
handler(tenant_id)
except Exception as e:
logger.error(f"Error in event handler for {event_type}: {e}", exc_info=True)
def add_event_handler(self, event_type: str, handler: Callable[[str], None]):
if event_type in self.event_handlers:
self.event_handlers[event_type] = handler
else:
logger.warning(f"Unknown event type: {event_type}")
tenant_id = message["tenantId"]
self.tenant_exchange.put(("delete", tenant_id))
def purge_queues(self):
    """Best-effort purge of the tenant event queues and their DLQ."""
    self.establish_connection()
    try:
        self.channel.queue_purge(self.tenant_created_queue_name)
        self.channel.queue_purge(self.tenant_deleted_queue_name)
        self.channel.queue_purge(self.tenant_events_dlq_name)
        logger.info("Queues purged.")
    except pika.exceptions.ChannelWrongStateError:
        # Deliberate best-effort: a channel in the wrong state (e.g. already
        # closed) just means there is nothing to purge right now.
        pass
class ServiceQueueManager(BaseQueueManager):
def __init__(self, settings: Dynaconf, tenant_manager: TenantQueueManager):
def __init__(self, settings: Dynaconf):
super().__init__(settings)
self.service_request_exchange_name = settings.rabbitmq.service_request_exchange_name
@ -312,8 +208,9 @@ class ServiceQueueManager(BaseQueueManager):
self.service_queue_prefix = settings.rabbitmq.service_request_queue_prefix
self.service_dlq_name = settings.rabbitmq.service_dlq_name
tenant_manager.add_event_handler("tenant_created", self.add_tenant_queue)
tenant_manager.add_event_handler("tenant_deleted", self.delete_tenant_queue)
self.tenant_ids = self.get_initial_tenant_ids(tenant_endpoint_url=settings.storage.tenant_server.endpoint)
self._consuming = False
def initialize_queues(self):
self.channel.exchange_declare(exchange=self.service_request_exchange_name, exchange_type="direct")
@ -331,24 +228,65 @@ class ServiceQueueManager(BaseQueueManager):
"x-max-priority": 2,
},
)
self.channel.queue_bind(queue_name, self.service_request_exchange_name)
def start_consuming(self, message_processor: Callable):
self.connection = self.establish_connection()
for tenant_id in self.tenant_ids:
queue_name = self.service_queue_prefix + "_" + tenant_id
message_callback = self._make_on_message_callback(message_processor=message_processor, tenant_id=tenant_id)
self.channel.basic_consume(
queue=queue_name,
on_message_callback=message_callback,
self.channel.queue_bind(
queue=queue_name, exchange=self.service_request_exchange_name, routing_key=tenant_id
)
logger.info(f"Starting to consume messages for queue {queue_name}...")
# self.channel.start_consuming()
if self.connection is not None:
self.connection.ioloop.start()
else:
logger.info("Connection is None, cannot start ioloop")
@retry(tries=3, delay=5, jitter=(1, 3), logger=logger, exceptions=requests.exceptions.HTTPError)
def get_initial_tenant_ids(self, tenant_endpoint_url: str) -> list:
    """Fetch the initial list of tenant ids from the tenant server.

    :param tenant_endpoint_url: HTTP endpoint returning a JSON array of
        tenant objects, each with a ``tenantId`` key (assumed schema —
        confirm against the tenant server contract).
    :return: list of tenant id strings; empty on a non-JSON or failed response.
    :raises requests.exceptions.HTTPError: on HTTP error status, so the
        ``@retry`` decorator can re-attempt the request.
    """
    # Default up front: the original left `tenants` unbound on the non-JSON
    # and error paths, making `return tenants` raise UnboundLocalError.
    tenants: list = []
    try:
        response = requests.get(tenant_endpoint_url, timeout=10)
        response.raise_for_status()  # raise HTTPError for bad responses
        # Substring match tolerates parameters like "application/json; charset=utf-8",
        # and .get() avoids a KeyError when the header is missing entirely.
        if "application/json" in response.headers.get("content-type", "").lower():
            tenants = [tenant["tenantId"] for tenant in response.json()]
        else:
            logger.warning("Response is not in JSON format.")
    except requests.exceptions.HTTPError:
        # Must propagate: @retry only retries on HTTPError, and the previous
        # broad handler swallowed it before the retry could ever run.
        raise
    except Exception:
        logger.warning("An unexpected error occurred while fetching tenant ids.", exc_info=True)
    return tenants
@retry(exceptions=pika.exceptions.AMQPConnectionError, tries=3, delay=5, jitter=(1, 3), logger=logger)
def start_sequential_consume(self, message_processor: Callable):
    """Round-robin poll each tenant's service queue and process one message at a time.

    Loops over ``self.tenant_ids``, pulling a single message per tenant queue
    with ``basic_get`` and dispatching it through the callback built by
    ``_make_on_message_callback``. After each full pass, tenant create/delete
    events queued by the tenant manager are applied via
    ``check_tenant_exchange``. Runs until ``self._consuming`` is cleared or
    the process is interrupted; ``stop_consuming`` always runs on exit.

    :param message_processor: callable applied to each unpacked message body.
    """
    self.establish_connection()
    self._consuming = True
    try:
        while self._consuming:
            for tenant_id in self.tenant_ids:
                queue_name = self.service_queue_prefix + "_" + tenant_id
                # basic_get is a synchronous single-message fetch (polling),
                # as opposed to a basic_consume subscription.
                method_frame, properties, body = self.channel.basic_get(queue_name)
                if method_frame:
                    on_message_callback = self._make_on_message_callback(message_processor, tenant_id)
                    on_message_callback(self.channel, method_frame, properties, body)
                else:
                    logger.debug("No message returned")
                    # NOTE(review): this sleeps once per *empty queue*, so a
                    # pass over many idle tenants is slow — confirm intended.
                    time.sleep(self.connection_sleep)
            ### Handle tenant events
            self.check_tenant_exchange()
    except KeyboardInterrupt:
        logger.info("Exiting...")
    finally:
        self.stop_consuming()
def check_tenant_exchange(self):
    """Drain pending tenant lifecycle events from the shared tenant exchange queue.

    Events are ``(event, tenant_id)`` tuples put on ``self.tenant_exchange``
    by the tenant manager: ``"create"`` adds a tenant queue, ``"delete"``
    removes one. Drains until the queue is empty.
    """
    while True:
        try:
            # Non-blocking read; queue.Empty is the normal "drained" signal.
            event, tenant = self.tenant_exchange.get_nowait()
        except queue.Empty:
            break
        try:
            if event == "create":
                self.on_tenant_created(tenant)
            elif event == "delete":
                self.on_tenant_deleted(tenant)
            else:
                # Previously an unknown event silently stopped draining,
                # stranding every event still in the queue.
                logger.warning(f"Unknown tenant event type: {event}")
        except Exception:
            # The original bare `except Exception: break` misreported real
            # handler failures as "No tenant exchange events" and dropped
            # the remaining events; log loudly and keep draining instead.
            logger.error(f"Error handling tenant event {event} for tenant {tenant}.", exc_info=True)
def publish_message_to_input_queue(
self, tenant_id: str, message: Union[str, bytes, dict], properties: pika.BasicProperties = None
@ -377,7 +315,7 @@ class ServiceQueueManager(BaseQueueManager):
except pika.exceptions.ChannelWrongStateError:
pass
def add_tenant_queue(self, tenant_id: str):
def on_tenant_created(self, tenant_id: str):
queue_name = self.service_queue_prefix + "_" + tenant_id
self.channel.queue_declare(
queue=queue_name,
@ -388,18 +326,16 @@ class ServiceQueueManager(BaseQueueManager):
"x-expires": self.queue_expiration_time, # TODO: check if necessary
},
)
self.channel.queue_bind(queue=queue_name, exchange=self.service_request_exchange_name)
# TODO: this is likely not possible due to blocking connection
message_callback = self._make_on_message_callback(message_processor=MessageProcessor, tenant_id=tenant_id)
self.channel.basic_consume(
queue=queue_name,
on_message_callback=message_callback,
)
self.channel.queue_bind(queue=queue_name, exchange=self.service_request_exchange_name, routing_key=tenant_id)
self.tenant_ids.append(tenant_id)
logger.debug(f"Added tenant {tenant_id}.")
def delete_tenant_queue(self, tenant_id: str):
def on_tenant_deleted(self, tenant_id: str):
queue_name = self.service_queue_prefix + "_" + tenant_id
self.channel.queue_unbind(queue_name, self.service_request_exchange_name)
self.channel.queue_unbind(queue=queue_name, exchange=self.service_request_exchange_name, routing_key=tenant_id)
self.channel.queue_delete(queue_name)
self.tenant_ids.remove(tenant_id)
logger.debug(f"Deleted tenant {tenant_id}.")
def _make_on_message_callback(self, message_processor: MessageProcessor, tenant_id: str):
def process_message_body_and_await_result(unpacked_message_body):
@ -409,16 +345,7 @@ class ServiceQueueManager(BaseQueueManager):
logger.info("Processing payload in separate thread.")
future = thread_pool_executor.submit(message_processor, unpacked_message_body)
# TODO: This block is probably not necessary, but kept since the implications of removing it are
# unclear. Remove it in a future iteration where less changes are being made to the code base.
# while future.running():
# logger.debug("Waiting for payload processing to finish...")
# self.connection.sleep(self.connection_sleep)
loop = asyncio.get_event_loop()
return loop.run_in_executor(None, future.result)
# return future.result()
return future.result()
def on_message_callback(channel, method, properties, body):
logger.info(f"Received message from queue with delivery_tag {method.delivery_tag}.")
@ -429,7 +356,7 @@ class ServiceQueueManager(BaseQueueManager):
return
if body.decode("utf-8") == "STOP":
logger.info("Received stop signal, stopping consuming...")
logger.info(f"Received stop signal, stopping consuming...")
channel.basic_ack(delivery_tag=method.delivery_tag)
self.stop_consuming()
return
@ -446,7 +373,7 @@ class ServiceQueueManager(BaseQueueManager):
)
channel.basic_publish(
exchange=self.service_request_exchange_name,
exchange=self.service_response_exchange_name,
routing_key=tenant_id,
body=json.dumps(result).encode(),
properties=pika.BasicProperties(headers=filtered_message_headers),
@ -463,4 +390,16 @@ class ServiceQueueManager(BaseQueueManager):
channel.basic_nack(method.delivery_tag, requeue=False)
raise
return on_message_callback
return on_message_callback
def stop_consuming(self):
    """Stop consuming and tear down the channel, then the connection."""
    # Clearing the flag first lets the sequential consume loop exit cleanly.
    self._consuming = False
    channel = self.channel
    if channel and channel.is_open:
        logger.info("Stopping consuming...")
        channel.stop_consuming()
        logger.info("Closing channel...")
        channel.close()
    connection = self.connection
    if connection and connection.is_open:
        logger.info("Closing connection to RabbitMQ...")
        connection.close()