refactor: cleanup codebase
parent 28451e8f8f
commit 1520e96287
@@ -1,343 +0,0 @@
import atexit
import concurrent.futures
import json
import logging
import requests
import signal
import sys
import time
from typing import Callable, Union

import pika
import pika.exceptions
from dynaconf import Dynaconf
from kn_utils.logging import logger
from pika.adapters.blocking_connection import BlockingChannel, BlockingConnection
from retry import retry

from pyinfra.config.loader import validate_settings
from pyinfra.config.validators import queue_manager_validators

logger.set_level("DEBUG")
pika_logger = logging.getLogger("pika")
pika_logger.setLevel(logging.WARNING)  # disables non-informative pika log clutter

MessageProcessor = Callable[[dict], dict]

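
# Illustrative sketch (not part of the original module): any callable matching the
# MessageProcessor alias above can be plugged into start_sequential_consume(). It receives the
# decoded JSON body merged with the filtered "x-" headers and returns the dict that is published
# to the response exchange. The field names below are hypothetical.
def example_message_processor(message_body: dict) -> dict:
    # Echo the identifiers back so the caller can correlate the response with its request.
    return {
        "tenantId": message_body.get("tenantId"),
        "dossierId": message_body.get("dossierId"),
        "status": "processed",
    }
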
class QueueManager:
    def __init__(self, settings: Dynaconf):
        validate_settings(settings, queue_manager_validators)

        self.tenant_created_queue_name = self.get_tenant_created_queue_name(settings)
        self.tenant_deleted_queue_name = self.get_tenant_deleted_queue_name(settings)
        self.tenant_events_dlq_name = self.get_tenant_events_dlq_name(settings)

        self.connection_sleep = settings.rabbitmq.connection_sleep
        self.queue_expiration_time = settings.rabbitmq.queue_expiration_time

        self.tenant_exchange_name = settings.rabbitmq.tenant_exchange_name
        self.service_request_exchange_name = settings.rabbitmq.service_request_exchange_name
        self.service_response_exchange_name = settings.rabbitmq.service_response_exchange_name

        self.service_queue_prefix = settings.rabbitmq.service_request_queue_prefix
        self.service_dlq_name = settings.rabbitmq.service_dlq_name

        self.connection_parameters = self.create_connection_parameters(settings)

        self.connection: Union[BlockingConnection, None] = None
        self.channel: Union[BlockingChannel, None] = None

        self.tenant_ids = self.get_initial_tenant_ids(tenant_endpoint_url=settings.storage.tenant_server.endpoint)

        self._consuming = False

        atexit.register(self.stop_consuming)
        signal.signal(signal.SIGTERM, self._handle_stop_signal)
        signal.signal(signal.SIGINT, self._handle_stop_signal)

    @staticmethod
    def create_connection_parameters(settings: Dynaconf):
        credentials = pika.PlainCredentials(username=settings.rabbitmq.username, password=settings.rabbitmq.password)
        pika_connection_params = {
            "host": settings.rabbitmq.host,
            "port": settings.rabbitmq.port,
            "credentials": credentials,
            "heartbeat": settings.rabbitmq.heartbeat,
        }

        return pika.ConnectionParameters(**pika_connection_params)

    @retry(tries=3, delay=5, jitter=(1, 3), logger=logger, exceptions=requests.exceptions.HTTPError)
    def get_initial_tenant_ids(self, tenant_endpoint_url: str) -> list:
        tenants = []
        try:
            response = requests.get(tenant_endpoint_url, timeout=10)
            response.raise_for_status()  # Raise an HTTPError for bad responses

            if response.headers["content-type"].lower() == "application/json":
                tenants = [tenant["tenantId"] for tenant in response.json()]
            else:
                logger.warning("Response is not in JSON format.")
        except Exception as e:
            logger.warning(f"An unexpected error occurred: {e}")

        return tenants

    def get_tenant_created_queue_name(self, settings: Dynaconf):
        return self.get_queue_name_with_suffix(
            suffix=settings.rabbitmq.tenant_created_event_queue_suffix, pod_name=settings.kubernetes.pod_name
        )

    def get_tenant_deleted_queue_name(self, settings: Dynaconf):
        return self.get_queue_name_with_suffix(
            suffix=settings.rabbitmq.tenant_deleted_event_queue_suffix, pod_name=settings.kubernetes.pod_name
        )

    def get_tenant_events_dlq_name(self, settings: Dynaconf):
        return self.get_queue_name_with_suffix(
            suffix=settings.rabbitmq.tenant_event_dlq_suffix, pod_name=settings.kubernetes.pod_name
        )

    def get_queue_name_with_suffix(self, suffix: str, pod_name: str):
        if not self.use_default_queue_name() and pod_name:
            return f"{pod_name}{suffix}"
        return self.get_default_queue_name()

    def use_default_queue_name(self):
        return False

    def get_default_queue_name(self):
        raise NotImplementedError("Queue name method not implemented")

    @retry(tries=3, delay=5, jitter=(1, 3), logger=logger)
    def establish_connection(self):
        # TODO: set sensible retry parameters
        if self.connection and self.connection.is_open:
            logger.debug("Connection to RabbitMQ already established.")
            return

        logger.info("Establishing connection to RabbitMQ...")
        logger.info(self.__class__.__name__)
        self.connection = pika.BlockingConnection(parameters=self.connection_parameters)

        logger.debug("Opening channel...")
        self.channel = self.connection.channel()
        self.channel.basic_qos(prefetch_count=1)

        args = {
            "x-dead-letter-exchange": "",
            "x-dead-letter-routing-key": self.tenant_events_dlq_name,
        }

        # Declare exchanges for tenants and responses
        self.channel.exchange_declare(exchange=self.tenant_exchange_name, exchange_type="topic")
        self.channel.exchange_declare(exchange=self.service_request_exchange_name, exchange_type="direct")
        self.channel.exchange_declare(exchange=self.service_response_exchange_name, exchange_type="direct")

        self.channel.queue_declare(self.tenant_created_queue_name, arguments=args, auto_delete=False, durable=True)
        self.channel.queue_declare(self.tenant_deleted_queue_name, arguments=args, auto_delete=False, durable=True)

        self.channel.queue_bind(
            exchange=self.tenant_exchange_name, queue=self.tenant_created_queue_name, routing_key="tenant.created"
        )
        self.channel.queue_bind(
            exchange=self.tenant_exchange_name, queue=self.tenant_deleted_queue_name, routing_key="tenant.delete"
        )

        for tenant_id in self.tenant_ids:
            queue_name = self.service_queue_prefix + "_" + tenant_id
            self.channel.queue_declare(
                queue=queue_name,
                durable=True,
                arguments={
                    "x-dead-letter-exchange": "",
                    "x-dead-letter-routing-key": self.service_dlq_name,
                    "x-expires": self.queue_expiration_time,  # TODO: check if necessary
                    "x-max-priority": 2,
                },
            )
            self.channel.queue_bind(
                queue=queue_name, exchange=self.service_request_exchange_name, routing_key=tenant_id
            )

        logger.info("Connection to RabbitMQ established, channel open.")

    def is_ready(self):
        self.establish_connection()
        return self.channel.is_open

    @retry(exceptions=pika.exceptions.AMQPConnectionError, tries=3, delay=5, jitter=(1, 3), logger=logger)
    def start_sequential_consume(self, message_processor: Callable):
        self.establish_connection()
        self._consuming = True

        try:
            while self._consuming:
                for tenant_id in self.tenant_ids:
                    queue_name = self.service_queue_prefix + "_" + tenant_id
                    method_frame, properties, body = self.channel.basic_get(queue_name)
                    if method_frame:
                        on_message_callback = self._make_on_message_callback(message_processor, tenant_id)
                        on_message_callback(self.channel, method_frame, properties, body)
                    else:
                        logger.debug("No message returned")
                        time.sleep(self.connection_sleep)

                # Handle tenant events
                self.check_tenant_created_queue()
                self.check_tenant_deleted_queue()

        except KeyboardInterrupt:
            logger.info("Exiting...")
        finally:
            self.stop_consuming()

    def check_tenant_created_queue(self):
        while True:
            method_frame, properties, body = self.channel.basic_get(self.tenant_created_queue_name)
            if method_frame:
                self.channel.basic_ack(delivery_tag=method_frame.delivery_tag)
                message = json.loads(body)
                tenant_id = message["tenantId"]
                self.on_tenant_created(tenant_id)
            else:
                logger.debug("No more tenant created events.")
                break

    def check_tenant_deleted_queue(self):
        while True:
            method_frame, properties, body = self.channel.basic_get(self.tenant_deleted_queue_name)
            if method_frame:
                self.channel.basic_ack(delivery_tag=method_frame.delivery_tag)
                message = json.loads(body)
                tenant_id = message["tenantId"]
                self.on_tenant_deleted(tenant_id)
            else:
                logger.debug("No more tenant deleted events.")
                break

    def on_tenant_created(self, tenant_id: str):
        queue_name = self.service_queue_prefix + "_" + tenant_id
        self.channel.queue_declare(
            queue=queue_name,
            durable=True,
            arguments={
                "x-dead-letter-exchange": "",
                "x-dead-letter-routing-key": self.service_dlq_name,
                "x-expires": self.queue_expiration_time,  # TODO: check if necessary
            },
        )
        self.channel.queue_bind(queue=queue_name, exchange=self.service_request_exchange_name, routing_key=tenant_id)
        self.tenant_ids.append(tenant_id)

    def on_tenant_deleted(self, tenant_id: str):
        queue_name = self.service_queue_prefix + "_" + tenant_id
        self.channel.queue_unbind(queue=queue_name, exchange=self.service_request_exchange_name, routing_key=tenant_id)
        self.channel.queue_delete(queue_name)
        self.tenant_ids.remove(tenant_id)

    def stop_consuming(self):
        self._consuming = False
        if self.channel and self.channel.is_open:
            logger.info("Stopping consuming...")
            self.channel.stop_consuming()
            logger.info("Closing channel...")
            self.channel.close()

        if self.connection and self.connection.is_open:
            logger.info("Closing connection to RabbitMQ...")
            self.connection.close()

    def publish_message_to_input_queue(
        self, tenant_id: str, message: Union[str, bytes, dict], properties: pika.BasicProperties = None
    ):
        if isinstance(message, str):
            message = message.encode("utf-8")
        elif isinstance(message, dict):
            message = json.dumps(message).encode("utf-8")

        self.establish_connection()
        self.channel.basic_publish(
            exchange=self.service_request_exchange_name,
            routing_key=tenant_id,
            properties=properties,
            body=message,
        )
        logger.info(f"Published message to queue {tenant_id}.")

    def purge_queues(self):
        self.establish_connection()
        try:
            self.channel.queue_purge(self.tenant_created_queue_name)
            self.channel.queue_purge(self.tenant_deleted_queue_name)
            for tenant_id in self.tenant_ids:
                queue_name = self.service_queue_prefix + "_" + tenant_id
                self.channel.queue_purge(queue_name)
            logger.info("Queues purged.")
        except pika.exceptions.ChannelWrongStateError:
            pass

    def get_message_from_output_queue(self, queue: str):
        self.establish_connection()
        return self.channel.basic_get(queue, auto_ack=True)

    def _make_on_message_callback(self, message_processor: MessageProcessor, tenant_id: str):
        def process_message_body_and_await_result(unpacked_message_body):
            # Processing the message in a separate thread is necessary for the main thread pika client to be able to
            # process data events (e.g. heartbeats) while the message is being processed.
            with concurrent.futures.ThreadPoolExecutor(max_workers=1) as thread_pool_executor:
                logger.info("Processing payload in separate thread.")
                future = thread_pool_executor.submit(message_processor, unpacked_message_body)

                return future.result()

        def on_message_callback(channel, method, properties, body):
            logger.info(f"Received message from queue with delivery_tag {method.delivery_tag}.")

            if method.redelivered:
                logger.warning(f"Declining message with {method.delivery_tag=} due to it being redelivered.")
                channel.basic_nack(method.delivery_tag, requeue=False)
                return

            if body.decode("utf-8") == "STOP":
                logger.info("Received stop signal, stopping consuming...")
                channel.basic_ack(delivery_tag=method.delivery_tag)
                self.stop_consuming()
                return

            try:
                filtered_message_headers = (
                    {k: v for k, v in properties.headers.items() if k.lower().startswith("x-")}
                    if properties.headers
                    else {}
                )
                logger.debug(f"Processing message with {filtered_message_headers=}.")
                result: dict = (
                    process_message_body_and_await_result({**json.loads(body), **filtered_message_headers}) or {}
                )

                channel.basic_publish(
                    exchange=self.service_response_exchange_name,
                    routing_key=tenant_id,
                    body=json.dumps(result).encode(),
                    properties=pika.BasicProperties(headers=filtered_message_headers),
                )
                logger.info(f"Published result to queue {tenant_id}.")

                channel.basic_ack(delivery_tag=method.delivery_tag)
                logger.debug(f"Message with {method.delivery_tag=} acknowledged.")
            except FileNotFoundError as e:
                logger.warning(f"{e}, declining message with {method.delivery_tag=}.")
                channel.basic_nack(method.delivery_tag, requeue=False)
            except Exception:
                logger.warning(f"Failed to process message with {method.delivery_tag=}, declining...", exc_info=True)
                channel.basic_nack(method.delivery_tag, requeue=False)
                raise

        return on_message_callback

    def _handle_stop_signal(self, signum, *args, **kwargs):
        logger.info(f"Received signal {signum}, stopping consuming...")
        self.stop_consuming()
        sys.exit(0)
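
# Illustrative usage sketch (not part of the original module). Assumes a config directory that
# pyinfra.config.loader can load, as in the smoke-test script further below; the subclass and
# processor names are hypothetical.
if __name__ == "__main__":
    from pyinfra.config.loader import load_settings, local_pyinfra_root_path

    class PodQueueManager(QueueManager):
        def get_default_queue_name(self):
            # Hypothetical fallback used when no pod name is configured.
            return "pyinfra-default-tenant-events"

    def echo_processor(message_body: dict) -> dict:
        # Minimal processor: publish the request payload back unchanged as the result.
        return message_body

    example_settings = load_settings(local_pyinfra_root_path / "config/")
    manager = PodQueueManager(example_settings)
    manager.start_sequential_consume(echo_processor)  # blocks, polling one request queue per tenant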
@@ -1,432 +0,0 @@
import atexit
import concurrent.futures
import pika
import queue
import json
import logging
import signal
import sys
import requests
import time
import threading
import pika.exceptions
from dynaconf import Dynaconf
from typing import Callable, Union
from pika.adapters.blocking_connection import BlockingChannel, BlockingConnection
from kn_utils.logging import logger
from pika.channel import Channel
from retry import retry

from pyinfra.config.loader import validate_settings
from pyinfra.config.validators import queue_manager_validators

logger.set_level("DEBUG")
pika_logger = logging.getLogger("pika")
pika_logger.setLevel(logging.WARNING)  # disables non-informative pika log clutter

MessageProcessor = Callable[[dict], dict]

class BaseQueueManager:
    tenant_exchange_queue = queue.Queue()
    _connection = None
    _lock = threading.Lock()
    should_stop = threading.Event()

    def __init__(self, settings: Dynaconf):
        validate_settings(settings, queue_manager_validators)

        self.connection_parameters = self.create_connection_parameters(settings)
        self.connection: Union[BlockingConnection, None] = None
        self.channel: Union[BlockingChannel, None] = None
        self.connection_sleep = settings.rabbitmq.connection_sleep
        self.queue_expiration_time = settings.rabbitmq.queue_expiration_time
        self.tenant_exchange_name = settings.rabbitmq.tenant_exchange_name

        atexit.register(self.stop_consuming)
        signal.signal(signal.SIGTERM, self._handle_stop_signal)
        signal.signal(signal.SIGINT, self._handle_stop_signal)

    @staticmethod
    def create_connection_parameters(settings: Dynaconf) -> pika.ConnectionParameters:
        credentials = pika.PlainCredentials(username=settings.rabbitmq.username, password=settings.rabbitmq.password)
        pika_connection_params = {
            "host": settings.rabbitmq.host,
            "port": settings.rabbitmq.port,
            "credentials": credentials,
            "heartbeat": settings.rabbitmq.heartbeat,
        }
        return pika.ConnectionParameters(**pika_connection_params)

    def get_connection(self) -> BlockingConnection:
        with self._lock:
            if not self._connection or self._connection.is_closed:
                self._connection = pika.BlockingConnection(self.connection_parameters)
            return self._connection

    @retry(tries=3, delay=5, jitter=(1, 3), logger=logger)
    def establish_connection(self) -> None:
        logger.info(f"Establishing connection to RabbitMQ for {self.__class__.__name__}...")
        self.connection = self.get_connection()
        if not self.channel or self.channel.is_closed:
            logger.debug("Opening channel...")
            self.channel = self.connection.channel()
            self.channel.basic_qos(prefetch_count=1)
            self.initialize_queues()
        logger.info(f"Connection to RabbitMQ established for {self.__class__.__name__}, channel open.")

    def is_ready(self) -> bool:
        self.establish_connection()
        return self.channel.is_open

    def initialize_queues(self) -> None:
        raise NotImplementedError("Subclasses should implement this method")

    def stop_consuming(self) -> None:
        if not self.should_stop.is_set():
            self.should_stop.set()
        if self.channel and self.channel.is_open:
            try:
                self.channel.stop_consuming()
                self.channel.close()
            except Exception as e:
                logger.error(f"Error stopping consuming: {e}", exc_info=True)
        if self.connection and self.connection.is_open:
            try:
                self.connection.close()
            except Exception as e:
                logger.error(f"Error closing connection: {e}", exc_info=True)

    def _handle_stop_signal(self, signum, *args, **kwargs) -> None:
        logger.info(f"Received signal {signum}, stopping consuming...")
        self.stop_consuming()
        sys.exit(0)

class TenantQueueManager(BaseQueueManager):
    def __init__(self, settings: Dynaconf):
        super().__init__(settings)

        self.tenant_created_queue_name = self.get_tenant_created_queue_name(settings)
        self.tenant_deleted_queue_name = self.get_tenant_deleted_queue_name(settings)
        self.tenant_events_dlq_name = self.get_tenant_events_dlq_name(settings)

    def initialize_queues(self) -> None:
        self.channel.exchange_declare(exchange=self.tenant_exchange_name, exchange_type="topic", durable=True)

        self.channel.queue_declare(
            queue=self.tenant_created_queue_name,
            arguments={
                "x-dead-letter-exchange": "",
                "x-dead-letter-routing-key": self.tenant_events_dlq_name,
            },
            durable=True,
        )
        self.channel.queue_declare(
            queue=self.tenant_deleted_queue_name,
            arguments={
                "x-dead-letter-exchange": "",
                "x-dead-letter-routing-key": self.tenant_events_dlq_name,
            },
            durable=True,
        )
        self.channel.queue_declare(
            queue=self.tenant_events_dlq_name,
            durable=True,
        )

        self.channel.queue_bind(
            exchange=self.tenant_exchange_name, queue=self.tenant_created_queue_name, routing_key="tenant.created"
        )
        self.channel.queue_bind(
            exchange=self.tenant_exchange_name, queue=self.tenant_deleted_queue_name, routing_key="tenant.delete"
        )

    @retry(exceptions=pika.exceptions.AMQPConnectionError, tries=3, delay=5, jitter=(1, 3), logger=logger)
    def start_consuming(self) -> None:
        try:
            self.establish_connection()
            self.channel.basic_consume(queue=self.tenant_created_queue_name, on_message_callback=self.on_tenant_created)
            self.channel.basic_consume(queue=self.tenant_deleted_queue_name, on_message_callback=self.on_tenant_deleted)
            self.channel.start_consuming()
        except Exception:
            logger.error("An unexpected error occurred while consuming messages. Consuming will stop.", exc_info=True)
            raise
        finally:
            self.stop_consuming()

    def get_tenant_created_queue_name(self, settings: Dynaconf) -> str:
        return self.get_queue_name_with_suffix(
            suffix=settings.rabbitmq.tenant_created_event_queue_suffix, pod_name=settings.kubernetes.pod_name
        )

    def get_tenant_deleted_queue_name(self, settings: Dynaconf) -> str:
        return self.get_queue_name_with_suffix(
            suffix=settings.rabbitmq.tenant_deleted_event_queue_suffix, pod_name=settings.kubernetes.pod_name
        )

    def get_tenant_events_dlq_name(self, settings: Dynaconf) -> str:
        return self.get_queue_name_with_suffix(
            suffix=settings.rabbitmq.tenant_event_dlq_suffix, pod_name=settings.kubernetes.pod_name
        )

    def get_queue_name_with_suffix(self, suffix: str, pod_name: str) -> str:
        if not self.use_default_queue_name() and pod_name:
            return f"{pod_name}{suffix}"
        return self.get_default_queue_name()

    def use_default_queue_name(self) -> bool:
        return False

    def get_default_queue_name(self):
        raise NotImplementedError("Queue name method not implemented")

    def on_tenant_created(self, ch: Channel, method, properties, body) -> None:
        logger.info("Received tenant created event")
        message = json.loads(body)
        ch.basic_ack(delivery_tag=method.delivery_tag)

        tenant_id = message["tenantId"]
        self.tenant_exchange_queue.put(("create", tenant_id))

    def on_tenant_deleted(self, ch: Channel, method, properties, body) -> None:
        logger.info("Received tenant deleted event")
        message = json.loads(body)
        ch.basic_ack(delivery_tag=method.delivery_tag)

        tenant_id = message["tenantId"]
        self.tenant_exchange_queue.put(("delete", tenant_id))

    def purge_queues(self) -> None:
        self.establish_connection()
        try:
            self.channel.queue_purge(self.tenant_created_queue_name)
            self.channel.queue_purge(self.tenant_deleted_queue_name)
            logger.info("Queues purged.")
        except pika.exceptions.ChannelWrongStateError:
            pass

    def publish_message_to_tenant_created_queue(
        self, message: Union[str, bytes, dict], properties: pika.BasicProperties = None
    ) -> None:
        if isinstance(message, str):
            message = message.encode("utf-8")
        elif isinstance(message, dict):
            message = json.dumps(message).encode("utf-8")

        self.establish_connection()
        self.channel.basic_publish(
            exchange=self.tenant_exchange_name,
            routing_key="tenant.created",
            properties=properties,
            body=message,
        )
        logger.info(f"Published message to queue {self.tenant_created_queue_name}.")

    def publish_message_to_tenant_deleted_queue(
        self, message: Union[str, bytes, dict], properties: pika.BasicProperties = None
    ) -> None:
        if isinstance(message, str):
            message = message.encode("utf-8")
        elif isinstance(message, dict):
            message = json.dumps(message).encode("utf-8")

        self.establish_connection()
        self.channel.basic_publish(
            exchange=self.tenant_exchange_name,
            routing_key="tenant.delete",
            properties=properties,
            body=message,
        )
        logger.info(f"Published message to queue {self.tenant_deleted_queue_name}.")

class ServiceQueueManager(BaseQueueManager):
    def __init__(self, settings: Dynaconf):
        super().__init__(settings)

        self.service_request_exchange_name = settings.rabbitmq.service_request_exchange_name
        self.service_response_exchange_name = settings.rabbitmq.service_response_exchange_name

        self.service_request_queue_prefix = settings.rabbitmq.service_request_queue_prefix

        self.service_dlq_name = settings.rabbitmq.service_dlq_name

        self.tenant_ids = self.get_initial_tenant_ids(tenant_endpoint_url=settings.storage.tenant_server.endpoint)

    def initialize_queues(self) -> None:
        self.channel.exchange_declare(exchange=self.service_request_exchange_name, exchange_type="direct", durable=True)
        self.channel.exchange_declare(exchange=self.service_response_exchange_name, exchange_type="direct", durable=True)

        for tenant_id in self.tenant_ids:
            request_queue_name = f"{self.service_request_queue_prefix}_{tenant_id}"
            self.channel.queue_declare(
                queue=request_queue_name,
                durable=True,
                arguments={
                    "x-dead-letter-exchange": "",
                    "x-dead-letter-routing-key": self.service_dlq_name,
                    "x-expires": self.queue_expiration_time,  # TODO: check if necessary
                    "x-max-priority": 2,
                },
            )
            self.channel.queue_bind(
                queue=request_queue_name, exchange=self.service_request_exchange_name, routing_key=tenant_id
            )

    @retry(
        tries=3,
        delay=5,
        jitter=(1, 3),
        logger=logger,
        exceptions=(requests.exceptions.HTTPError, requests.exceptions.ConnectionError),
    )
    def get_initial_tenant_ids(self, tenant_endpoint_url: str) -> list:
        response = requests.get(tenant_endpoint_url, timeout=10)
        response.raise_for_status()  # Raise an HTTPError for bad responses

        if response.headers["content-type"].lower() == "application/json":
            tenants = [tenant["tenantId"] for tenant in response.json()]
            return tenants
        return []

    @retry(exceptions=pika.exceptions.AMQPConnectionError, tries=3, delay=5, jitter=(1, 3), logger=logger)
    def start_sequential_basic_get(self, message_processor: Callable) -> None:
        self.establish_connection()
        try:
            while not self.should_stop.is_set():
                for tenant_id in self.tenant_ids:
                    queue_name = f"{self.service_request_queue_prefix}_{tenant_id}"
                    method_frame, properties, body = self.channel.basic_get(queue_name)
                    if method_frame:
                        logger.debug("PROCESSING MESSAGE")
                        on_message_callback = self._make_on_message_callback(message_processor, tenant_id)
                        on_message_callback(self.channel, method_frame, properties, body)
                    else:
                        logger.debug(f"No message returned for queue {queue_name}")
                        # time.sleep(self.connection_sleep)
                        time.sleep(0.1)

                # Handle tenant events
                self.check_tenant_exchange()

        except KeyboardInterrupt:
            logger.info("Exiting...")
        finally:
            self.stop_consuming()

    def check_tenant_exchange(self) -> None:
        while not self.tenant_exchange_queue.empty():
            try:
                event, tenant = self.tenant_exchange_queue.get_nowait()
                if event == "create":
                    self.on_tenant_created(tenant)
                elif event == "delete":
                    self.on_tenant_deleted(tenant)
            except queue.Empty:
                # time.sleep(self.connection_sleep)
                break

    def publish_message_to_input_queue(
        self, tenant_id: str, message: Union[str, bytes, dict], properties: pika.BasicProperties = None
    ) -> None:
        if isinstance(message, str):
            message = message.encode("utf-8")
        elif isinstance(message, dict):
            message = json.dumps(message).encode("utf-8")

        self.establish_connection()
        self.channel.basic_publish(
            exchange=self.service_request_exchange_name,
            routing_key=tenant_id,
            properties=properties,
            body=message,
        )
        logger.info(f"Published message to queue {tenant_id}.")

    def purge_queues(self) -> None:
        self.establish_connection()
        try:
            for tenant_id in self.tenant_ids:
                request_queue_name = f"{self.service_request_queue_prefix}_{tenant_id}"
                self.channel.queue_purge(request_queue_name)
            logger.info("Queues purged.")
        except pika.exceptions.ChannelWrongStateError:
            pass

    def on_tenant_created(self, tenant_id: str) -> None:
        request_queue_name = f"{self.service_request_queue_prefix}_{tenant_id}"
        self.channel.queue_declare(
            queue=request_queue_name,
            durable=True,
            arguments={
                "x-dead-letter-exchange": "",
                "x-dead-letter-routing-key": self.service_dlq_name,
                "x-expires": self.queue_expiration_time,  # TODO: check if necessary
            },
        )
        self.channel.queue_bind(queue=request_queue_name, exchange=self.service_request_exchange_name, routing_key=tenant_id)

        self.tenant_ids.append(tenant_id)
        logger.debug(f"Added tenant {tenant_id}.")

    def on_tenant_deleted(self, tenant_id: str) -> None:
        request_queue_name = f"{self.service_request_queue_prefix}_{tenant_id}"
        self.channel.queue_unbind(queue=request_queue_name, exchange=self.service_request_exchange_name, routing_key=tenant_id)
        self.channel.queue_delete(request_queue_name)

        self.tenant_ids.remove(tenant_id)
        logger.debug(f"Deleted tenant {tenant_id}.")

    def _make_on_message_callback(self, message_processor: MessageProcessor, tenant_id: str) -> Callable:
        def process_message_body_and_await_result(unpacked_message_body):
            # Processing the message in a separate thread is necessary for the main thread pika client to be able to
            # process data events (e.g. heartbeats) while the message is being processed.
            with concurrent.futures.ThreadPoolExecutor(max_workers=1) as thread_pool_executor:
                logger.info("Processing payload in separate thread.")
                future = thread_pool_executor.submit(message_processor, unpacked_message_body)

                return future.result()

        def on_message_callback(channel, method, properties, body):
            logger.info(f"Received message from queue with delivery_tag {method.delivery_tag}.")

            if method.redelivered:
                logger.warning(f"Declining message with {method.delivery_tag=} due to it being redelivered.")
                channel.basic_nack(method.delivery_tag, requeue=False)
                return

            if body.decode("utf-8") == "STOP":
                logger.info("Received stop signal, stopping consuming...")
                channel.basic_ack(delivery_tag=method.delivery_tag)
                self.stop_consuming()
                return

            try:
                filtered_message_headers = (
                    {k: v for k, v in properties.headers.items() if k.lower().startswith("x-")}
                    if properties.headers
                    else {}
                )
                logger.debug(f"Processing message with {filtered_message_headers=}.")
                result: dict = (
                    process_message_body_and_await_result({**json.loads(body), **filtered_message_headers}) or {}
                )

                channel.basic_publish(
                    exchange=self.service_response_exchange_name,
                    routing_key=tenant_id,
                    body=json.dumps(result).encode(),
                    properties=pika.BasicProperties(headers=filtered_message_headers),
                )
                logger.info(f"Published result to queue {tenant_id}.")

                channel.basic_ack(delivery_tag=method.delivery_tag)
                logger.debug(f"Message with {method.delivery_tag=} acknowledged.")
            except FileNotFoundError as e:
                logger.warning(f"{e}, declining message with {method.delivery_tag=}.")
                channel.basic_nack(method.delivery_tag, requeue=False)
            except Exception:
                logger.warning(f"Failed to process message with {method.delivery_tag=}, declining...", exc_info=True)
                channel.basic_nack(method.delivery_tag, requeue=False)
                raise

        return on_message_callback
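
# Illustrative wiring sketch (not part of the original module). The shared class-level
# tenant_exchange_queue and should_stop event suggest the intended split: TenantQueueManager
# consumes tenant events on a background thread and enqueues ("create"/"delete", tenant_id)
# tuples, while ServiceQueueManager polls the request queues on the main thread and drains
# those tuples via check_tenant_exchange(). Names below are hypothetical.
if __name__ == "__main__":
    from pyinfra.config.loader import load_settings, local_pyinfra_root_path

    example_settings = load_settings(local_pyinfra_root_path / "config/")

    tenant_manager = TenantQueueManager(example_settings)
    service_manager = ServiceQueueManager(example_settings)

    # Consume tenant created/deleted events in the background.
    tenant_thread = threading.Thread(target=tenant_manager.start_consuming, daemon=True)
    tenant_thread.start()

    def echo_processor(message_body: dict) -> dict:
        # Minimal processor for demonstration: return the request payload unchanged.
        return message_body

    # Poll per-tenant request queues until a stop signal arrives.
    service_manager.start_sequential_basic_get(echo_processor)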
@@ -1,100 +0,0 @@
import gzip
import json
import time
from operator import itemgetter
from threading import Thread

from kn_utils.logging import logger

from pyinfra.config.loader import load_settings, local_pyinfra_root_path
from pyinfra.queue.threaded_tenants import ServiceQueueManager, TenantQueueManager
from pyinfra.storage.storages.s3 import get_s3_storage_from_settings

settings = load_settings(local_pyinfra_root_path / "config/")

def upload_json_and_make_message_body(tenant_id: str):
    dossier_id, file_id, suffix = "dossier", "file", "json.gz"
    content = {
        "numberOfPages": 7,
        "sectionTexts": "data",
    }

    object_name = f"{tenant_id}/{dossier_id}/{file_id}.{suffix}"
    data = gzip.compress(json.dumps(content).encode("utf-8"))

    storage = get_s3_storage_from_settings(settings)
    if not storage.has_bucket():
        storage.make_bucket()
    storage.put_object(object_name, data)

    message_body = {
        "tenantId": tenant_id,
        "dossierId": dossier_id,
        "fileId": file_id,
        "targetFileExtension": suffix,
        "responseFileExtension": f"result.{suffix}",
    }
    return message_body


def tenant_event_message(tenant_id: str):
    return {"tenantId": tenant_id}


def send_tenant_event(tenant_id: str, event_type: str):
    queue_manager = TenantQueueManager(settings)
    queue_manager.purge_queues()
    message = tenant_event_message(tenant_id)
    if event_type == "create":
        queue_manager.publish_message_to_tenant_created_queue(message=message)
    elif event_type == "delete":
        queue_manager.publish_message_to_tenant_deleted_queue(message=message)
    else:
        logger.warning(f"Event type '{event_type}' not known.")
    queue_manager.stop_consuming()


def send_service_request(tenant_id: str):
    queue_manager = ServiceQueueManager(settings)
    queue_name = f"service_response_queue_{tenant_id}"

    queue_manager.purge_queues()

    message = upload_json_and_make_message_body(tenant_id)

    queue_manager.publish_message_to_input_queue(tenant_id=tenant_id, message=message)
    logger.info(f"Put {message} on {queue_name}.")

    storage = get_s3_storage_from_settings(settings)

    for method_frame, properties, body in queue_manager.channel.consume(queue=queue_name, inactivity_timeout=15):
        if not body:
            break
        response = json.loads(body)
        logger.info(f"Received {response}")
        logger.info(f"Message headers: {properties.headers}")
        queue_manager.channel.basic_ack(method_frame.delivery_tag)
        tenant_id, dossier_id, file_id = itemgetter("tenantId", "dossierId", "fileId")(response)
        suffix = message["responseFileExtension"]
        print(f"{tenant_id}/{dossier_id}/{file_id}.{suffix}")
        result = storage.get_object(f"{tenant_id}/{dossier_id}/{file_id}.{suffix}")
        result = json.loads(gzip.decompress(result))
        logger.info(f"Contents of result on storage: {result}")
        break
    queue_manager.stop_consuming()


if __name__ == "__main__":
    import uuid

    unique_ids = [str(uuid.uuid4()) for _ in range(100)]

    for tenant in unique_ids:
        send_tenant_event(tenant_id=tenant, event_type="create")

    # for tenant in unique_ids:
    #     send_service_request(tenant_id=tenant)

    # for tenant in unique_ids:
    #     send_tenant_event(tenant_id=tenant, event_type="delete")