# RabbitMQ queue management: connection handling, consuming, and publishing.
import atexit
import concurrent.futures
import json
import logging
import signal
import sys
from typing import Callable, Union

import pika
import pika.exceptions
from dynaconf import Dynaconf
from kn_utils.logging import logger
from pika.adapters.blocking_connection import BlockingChannel, BlockingConnection
from retry import retry

from pyinfra.config.loader import validate_settings
from pyinfra.config.validators import queue_manager_validators
# Pika logs a lot of non-informative INFO chatter; only surface warnings and up.
pika_logger = logging.getLogger("pika")
pika_logger.setLevel(logging.WARNING)

# A message processor receives the decoded message payload (dict) and returns
# a result payload (dict) to be published to the output queue.
MessageProcessor = Callable[[dict], dict]
class QueueManager:
    """Owns a blocking RabbitMQ connection/channel pair for one service.

    Declares the configured input/output queues (with dead-lettering to the
    configured dead letter queue), consumes messages from the input queue,
    runs a processor callable on each payload, and publishes the result to
    the output queue. Installs SIGTERM/SIGINT and atexit handlers so the
    connection is shut down gracefully on termination.
    """

    def __init__(self, settings: Dynaconf):
        """Validate *settings* and prepare connection parameters (no I/O yet).

        Args:
            settings: Dynaconf settings; must satisfy ``queue_manager_validators``
                (rabbitmq queue names, credentials, host, port, heartbeat,
                connection_sleep).
        """
        validate_settings(settings, queue_manager_validators)

        self.input_queue = settings.rabbitmq.input_queue
        self.output_queue = settings.rabbitmq.output_queue
        self.dead_letter_queue = settings.rabbitmq.dead_letter_queue

        self.connection_parameters = self.create_connection_parameters(settings)

        self.connection: Union[BlockingConnection, None] = None
        self.channel: Union[BlockingChannel, None] = None
        # Interval (seconds) for servicing the connection (e.g. heartbeats)
        # while a message is being processed in a worker thread.
        self.connection_sleep = settings.rabbitmq.connection_sleep
        # True while the on-message callback is running; lets the signal
        # handler defer shutdown until the in-flight message is finished.
        self.processing_callback = False
        # Set by _handle_stop_signal; checked after each message so we can
        # exit between messages instead of mid-processing.
        self.received_signal = False

        atexit.register(self.stop_consuming)
        signal.signal(signal.SIGTERM, self._handle_stop_signal)
        signal.signal(signal.SIGINT, self._handle_stop_signal)

    @staticmethod
    def create_connection_parameters(settings: Dynaconf) -> pika.ConnectionParameters:
        """Build ``pika.ConnectionParameters`` from the ``rabbitmq`` settings section."""
        credentials = pika.PlainCredentials(username=settings.rabbitmq.username, password=settings.rabbitmq.password)
        pika_connection_params = {
            "host": settings.rabbitmq.host,
            "port": settings.rabbitmq.port,
            "credentials": credentials,
            "heartbeat": settings.rabbitmq.heartbeat,
        }

        return pika.ConnectionParameters(**pika_connection_params)

    @retry(tries=3, delay=5, jitter=(1, 3), logger=logger)
    def establish_connection(self):
        """Open connection and channel if needed, and declare both queues.

        Idempotent: returns immediately when the current connection is open.
        Retried (on any exception) up to 3 times with jittered delay.
        """
        # TODO: set sensible retry parameters
        if self.connection and self.connection.is_open:
            logger.debug("Connection to RabbitMQ already established.")
            return

        logger.info("Establishing connection to RabbitMQ...")
        self.connection = pika.BlockingConnection(parameters=self.connection_parameters)

        logger.debug("Opening channel...")
        self.channel = self.connection.channel()
        # At most one unacknowledged message in flight per consumer.
        self.channel.basic_qos(prefetch_count=1)

        # Rejected/nacked messages are routed through the default exchange ("")
        # to the dead letter queue.
        args = {
            "x-dead-letter-exchange": "",
            "x-dead-letter-routing-key": self.dead_letter_queue,
        }

        self.channel.queue_declare(self.input_queue, arguments=args, auto_delete=False, durable=True)
        self.channel.queue_declare(self.output_queue, arguments=args, auto_delete=False, durable=True)

        logger.info("Connection to RabbitMQ established, channel open.")

    def is_ready(self) -> bool:
        """Return True when a channel to RabbitMQ is open (connecting if necessary)."""
        self.establish_connection()
        return self.channel.is_open

    @retry(exceptions=pika.exceptions.AMQPConnectionError, tries=3, delay=5, jitter=(1, 3), logger=logger)
    def start_consuming(self, message_processor: Callable):
        """Block and consume the input queue, handling each message with *message_processor*.

        Always tears the channel/connection down on exit; connection errors
        are retried by the decorator, other exceptions are logged and re-raised.
        """
        on_message_callback = self._make_on_message_callback(message_processor)

        try:
            self.establish_connection()
            self.channel.basic_consume(self.input_queue, on_message_callback)
            self.channel.start_consuming()
        except Exception:
            logger.error("An unexpected error occurred while consuming messages. Consuming will stop.", exc_info=True)
            raise
        finally:
            self.stop_consuming()

    def stop_consuming(self):
        """Stop consuming and close channel and connection if open. Idempotent."""
        if self.channel and self.channel.is_open:
            logger.info("Stopping consuming...")
            self.channel.stop_consuming()
            logger.info("Closing channel...")
            self.channel.close()

        if self.connection and self.connection.is_open:
            logger.info("Closing connection to RabbitMQ...")
            self.connection.close()

    def publish_message_to_input_queue(
        self, message: Union[str, bytes, dict], properties: Union[pika.BasicProperties, None] = None
    ):
        """Publish *message* to the input queue, encoding str/dict bodies to UTF-8 bytes.

        Args:
            message: Raw bytes, a UTF-8 string, or a JSON-serializable dict.
            properties: Optional AMQP properties to attach to the message.
        """
        if isinstance(message, str):
            message = message.encode("utf-8")
        elif isinstance(message, dict):
            message = json.dumps(message).encode("utf-8")

        self.establish_connection()
        self.channel.basic_publish(
            "",
            self.input_queue,
            properties=properties,
            body=message,
        )
        logger.info(f"Published message to queue {self.input_queue}.")

    def purge_queues(self):
        """Remove all messages from the input and output queues (best effort)."""
        self.establish_connection()
        try:
            self.channel.queue_purge(self.input_queue)
            self.channel.queue_purge(self.output_queue)
            logger.info("Queues purged.")
        except pika.exceptions.ChannelWrongStateError:
            # Best effort by design, but don't swallow silently: record that
            # the purge was skipped because the channel was not usable.
            logger.debug("Channel in wrong state while purging queues; skipped.", exc_info=True)

    def get_message_from_output_queue(self):
        """Fetch (and auto-ack) a single message from the output queue.

        Returns:
            The ``(method, properties, body)`` triple from ``basic_get``;
            all three are ``None`` when the queue is empty.
        """
        self.establish_connection()
        return self.channel.basic_get(self.output_queue, auto_ack=True)

    def _make_on_message_callback(self, message_processor: MessageProcessor):
        """Wrap *message_processor* into a pika ``on_message_callback``.

        The callback decodes the JSON body, merges in ``x-*`` headers, runs
        the processor in a worker thread (so the main thread can keep
        servicing heartbeats), publishes the result to the output queue, and
        acks/nacks the delivery accordingly.
        """

        def process_message_body_and_await_result(unpacked_message_body):
            # Processing the message in a separate thread is necessary for the main thread pika client to be able to
            # process data events (e.g. heartbeats) while the message is being processed.
            with concurrent.futures.ThreadPoolExecutor(max_workers=1) as thread_pool_executor:
                logger.info("Processing payload in separate thread.")
                future = thread_pool_executor.submit(message_processor, unpacked_message_body)

                # TODO: This block is probably not necessary, but kept since the implications of removing it are
                # unclear. Remove it in a future iteration where less changes are being made to the code base.
                while future.running():
                    logger.debug("Waiting for payload processing to finish...")
                    self.connection.sleep(self.connection_sleep)

                return future.result()

        def on_message_callback(channel, method, properties, body):
            logger.info(f"Received message from queue with delivery_tag {method.delivery_tag}.")
            self.processing_callback = True

            if method.redelivered:
                # Redelivery means a previous attempt died mid-processing;
                # dead-letter it instead of risking a poison-message loop.
                logger.warning(f"Declining message with {method.delivery_tag=} due to it being redelivered.")
                channel.basic_nack(method.delivery_tag, requeue=False)
                # Bug fix: reset the flag on this early return, otherwise a
                # later stop signal would wait forever for "processing" to end.
                self.processing_callback = False
                return

            if body.decode("utf-8") == "STOP":
                logger.info("Received stop signal, stopping consuming...")
                channel.basic_ack(delivery_tag=method.delivery_tag)
                self.stop_consuming()
                # Bug fix: same early-return flag reset as above.
                self.processing_callback = False
                return

            try:
                # Only forward headers in the "x-" namespace to the processor
                # and the outgoing message.
                filtered_message_headers = (
                    {k: v for k, v in properties.headers.items() if k.lower().startswith("x-")}
                    if properties.headers
                    else {}
                )
                logger.debug(f"Processing message with {filtered_message_headers=}.")
                result: dict = (
                    process_message_body_and_await_result({**json.loads(body), **filtered_message_headers}) or {}
                )

                channel.basic_publish(
                    "",
                    self.output_queue,
                    json.dumps(result).encode(),
                    properties=pika.BasicProperties(headers=filtered_message_headers),
                )
                logger.info(f"Published result to queue {self.output_queue}.")

                channel.basic_ack(delivery_tag=method.delivery_tag)
                logger.debug(f"Message with {method.delivery_tag=} acknowledged.")
            except FileNotFoundError as e:
                logger.warning(f"{e}, declining message with {method.delivery_tag=}.")
                channel.basic_nack(method.delivery_tag, requeue=False)
            except Exception:
                logger.warning(f"Failed to process message with {method.delivery_tag=}, declining...", exc_info=True)
                channel.basic_nack(method.delivery_tag, requeue=False)
                raise

            finally:
                self.processing_callback = False
                if self.received_signal:
                    # A stop signal arrived while processing; finish the
                    # deferred shutdown now that the message is settled.
                    self.stop_consuming()
                    sys.exit(0)

        return on_message_callback

    def _handle_stop_signal(self, signum, *args, **kwargs):
        """SIGTERM/SIGINT handler: exit now, or defer until the in-flight message is done."""
        logger.info(f"Received signal {signum}, stopping consuming...")
        self.received_signal = True
        if not self.processing_callback:
            self.stop_consuming()
            sys.exit(0)