343 lines
14 KiB
Python
343 lines
14 KiB
Python
import atexit
|
|
import concurrent.futures
|
|
import json
|
|
import logging
|
|
import sys
|
|
import threading
|
|
import time
|
|
from functools import partial
|
|
from typing import Union, Callable
|
|
|
|
import pika
|
|
import pika.exceptions
|
|
import signal
|
|
|
|
from dynaconf import Dynaconf
|
|
from kn_utils.logging import logger
|
|
from pathlib import Path
|
|
from pika.adapters.blocking_connection import BlockingChannel, BlockingConnection
|
|
from retry import retry
|
|
|
|
from pyinfra.config import Config, load_settings
|
|
from pyinfra.exception import ProcessingFailure
|
|
from pyinfra.payload_processing.processor import PayloadProcessor
|
|
from pyinfra.utils.dict import safe_project
|
|
|
|
# Module-wide configuration instance, built once at import time.
CONFIG = Config()

# pika emits a lot of low-value connection chatter at INFO/DEBUG level;
# raise its threshold so only warnings and errors come through.
pika_logger = logging.getLogger(name="pika")
pika_logger.setLevel(level=logging.WARNING)
|
|
|
|
|
|
def get_connection_params(config: Config) -> pika.ConnectionParameters:
    """Build pika connection parameters from a pyinfra ``Config``.

    Args:
        config (pyinfra.Config): standard pyinfra config object; its
            ``rabbitmq_username``, ``rabbitmq_password``, ``rabbitmq_host``,
            ``rabbitmq_port`` and ``rabbitmq_heartbeat`` attributes are read.

    Returns:
        pika.ConnectionParameters: parameters ready to hand to a pika connection.
    """
    plain_credentials = pika.PlainCredentials(
        username=config.rabbitmq_username,
        password=config.rabbitmq_password,
    )
    return pika.ConnectionParameters(
        host=config.rabbitmq_host,
        port=config.rabbitmq_port,
        credentials=plain_credentials,
        heartbeat=config.rabbitmq_heartbeat,
    )
|
|
|
|
|
|
def _get_n_previous_attempts(props):
    """Return how often a message has already been retried.

    Reads the ``x-retry-count`` header from the message properties. A missing
    header dict and a missing key both count as zero previous attempts.
    """
    headers = props.headers
    if headers is None:
        return 0
    return headers.get("x-retry-count", 0)
|
|
|
|
|
|
def token_file_name():
    """Return the path of the file the consumer token is written to.

    Returns:
        pathlib.Path: the fixed location ``/tmp/consumer_token.txt``.
    """
    return Path("/tmp") / "consumer_token.txt"
|
|
|
|
|
|
class QueueManager:
    """Handle RabbitMQ message reception & delivery.

    Messages are consumed from ``config.request_queue``. Each body is JSON-decoded,
    merged with the headers selected by ``safe_project(..., ["X-TENANT-ID"])`` and
    handed to a payload processor running in a worker thread, so the connection can
    keep servicing data events (e.g. heartbeats) meanwhile. The JSON-encoded result
    is published to ``config.response_queue``. Both queues are declared with
    ``config.dead_letter_queue`` as dead-letter routing key, so messages rejected
    without requeue are routed there.
    """

    def __init__(self, config: Config):
        self._input_queue = config.request_queue
        self._output_queue = config.response_queue
        self._dead_letter_queue = config.dead_letter_queue

        # controls how often we send out a life signal
        self._heartbeat = config.rabbitmq_heartbeat

        # controls for how long we only process data events (e.g. heartbeats),
        # while the queue is blocked and we process the given callback function
        self._connection_sleep = config.rabbitmq_connection_sleep

        # Compared against the literal string "True"; any other value (including a
        # boolean True) disables writing the token file.
        self._write_token = config.write_consumer_token == "True"
        self._set_consumer_token(None)

        self._connection_params = get_connection_params(config)
        self._connection = pika.BlockingConnection(parameters=self._connection_params)
        # Only declared here; the channel itself is created in _open_channel().
        self._channel: BlockingChannel

        # necessary so pods can be terminated/restarted in K8s/docker
        atexit.register(self.stop_consuming)
        signal.signal(signal.SIGTERM, self._handle_stop_signal)
        signal.signal(signal.SIGINT, self._handle_stop_signal)

    def _set_consumer_token(self, token_value):
        """Remember the current consumer tag and, if enabled, mirror it to the token file.

        Args:
            token_value: consumer tag returned by ``basic_consume``, or ``None`` when
                not consuming (written to the file as an empty string).
        """
        self._consumer_token = token_value

        if self._write_token:
            token_file_path = token_file_name()

            with token_file_path.open(mode="w", encoding="utf8") as token_file:
                text = token_value if token_value is not None else ""
                token_file.write(text)

    def _open_channel(self):
        """Open a channel with prefetch 1 and declare both durable, dead-lettered queues."""
        self._channel = self._connection.channel()
        # at most one unacknowledged message is delivered at a time
        self._channel.basic_qos(prefetch_count=1)

        # messages rejected without requeue are routed through the default exchange
        # to the dead-letter queue
        args = {
            "x-dead-letter-exchange": "",
            "x-dead-letter-routing-key": self._dead_letter_queue,
        }

        self._channel.queue_declare(self._input_queue, arguments=args, auto_delete=False, durable=True)
        self._channel.queue_declare(self._output_queue, arguments=args, auto_delete=False, durable=True)

    def start_consuming(self, process_payload: PayloadProcessor):
        """Consume messages from the input queue until stopped or an error occurs.

        - wraps ``process_payload`` in the standard queue callback (see
          ``_create_queue_callback``), which runs processing in a worker thread so
          heartbeats keep being served
        - resets the consumer token, opens the channel and registers the consumer;
          the returned consumer tag becomes the new consumer token
        - blocks in pika's ``start_consuming`` loop
        - any exception is logged and re-raised; in every case the subscription is
          cancelled and the connection closed (if still open) afterwards

        Args:
            process_payload (Callable): function passed to the queue manager, configured by implementing service

        Raises:
            Exception: whatever escaped the consuming loop, re-raised unchanged.
        """
        callback = self._create_queue_callback(process_payload)
        self._set_consumer_token(None)

        try:
            self._open_channel()
            self._set_consumer_token(self._channel.basic_consume(self._input_queue, callback))
            logger.info(f"Registered with consumer-tag: {self._consumer_token}")
            self._channel.start_consuming()

        except Exception:
            logger.error(
                "An unexpected exception occurred while consuming messages. Consuming will stop.", exc_info=True
            )
            raise

        finally:
            self.stop_consuming()
            # The broker or network may already have closed the connection; closing it
            # again would raise inside this finally and mask the original exception.
            if self._connection.is_open:
                self._connection.close()

    def stop_consuming(self):
        """Cancel the active subscription, if any, and clear the consumer token.

        Safe to call repeatedly; it is also registered with ``atexit`` and used by the
        SIGTERM/SIGINT handler. The broker-side cancel is skipped when the connection
        or channel is already closed, but the token is cleared either way.
        """
        if not self._consumer_token:
            return

        # A consumer token is only set after _open_channel(), so self._channel exists here.
        if self._connection.is_open and self._channel.is_open:
            logger.info(f"Cancelling subscription for consumer-tag {self._consumer_token}")
            self._channel.stop_consuming(self._consumer_token)

        self._set_consumer_token(None)

    def _handle_stop_signal(self, signal_number, _stack_frame, *args, **kwargs):
        """Signal handler: log the signal and stop consuming."""
        logger.info(f"Received signal {signal_number}")
        self.stop_consuming()

    def _create_queue_callback(self, process_payload: PayloadProcessor):
        """Build the pika ``on_message_callback`` that handles one message end-to-end.

        Args:
            process_payload: callable invoked with the decoded message body (merged with
                the filtered headers); its return value is JSON-encoded and published
                to the output queue.

        Returns:
            A callable with pika's ``(channel, method, properties, body)`` callback signature.
        """

        def process_message_body_and_await_result(unpacked_message_body):
            with concurrent.futures.ThreadPoolExecutor(max_workers=1) as thread_pool_executor:
                logger.debug("Processing payload in separate thread.")
                future = thread_pool_executor.submit(process_payload, unpacked_message_body)

                # Poll with done() rather than running(): right after submit() the future is
                # usually still pending, so a running() loop could exit immediately and then
                # block in result() without servicing heartbeats.
                while not future.done():
                    logger.debug("Waiting for payload processing to finish...")
                    self._connection.sleep(float(self._connection_sleep))

                try:
                    return future.result()
                except Exception as err:
                    raise ProcessingFailure(f"QueueMessagePayload processing failed: {repr(err)}") from err

        def acknowledge_message_and_publish_response(frame, headers, response_body):
            # forward the filtered request headers with the response, if there are any
            response_properties = pika.BasicProperties(headers=headers) if headers else None
            self._channel.basic_publish("", self._output_queue, json.dumps(response_body).encode(), response_properties)
            logger.debug(f"Result published, acknowledging incoming message with delivery_tag {frame.delivery_tag}.")
            # acknowledge only after the response has been published
            self._channel.basic_ack(frame.delivery_tag)

        def callback(_channel, frame, properties, body):
            logger.info(f"Received message from queue with delivery_tag {frame.delivery_tag}.")
            logger.debug(f"Message headers: {properties.headers}")

            # Only try to process each message once. Re-queueing will be handled by the dead-letter-exchange. This
            # prevents endless retries on messages that are impossible to process.
            if frame.redelivered:
                logger.info(
                    f"Aborting message processing for delivery_tag {frame.delivery_tag} due to it being redelivered.",
                )
                self._channel.basic_nack(frame.delivery_tag, requeue=False)
                return

            try:
                logger.debug(f"Processing {frame}, {properties}, {body}")
                filtered_message_headers = safe_project(properties.headers, ["X-TENANT-ID"])
                message_body = {**json.loads(body), **filtered_message_headers}

                processing_result = process_message_body_and_await_result(message_body)
                logger.info(
                    f"Processed message with delivery_tag {frame.delivery_tag}, publishing result to result-queue."
                )
                acknowledge_message_and_publish_response(frame, filtered_message_headers, processing_result)

            except ProcessingFailure as err:
                # the payload processor itself failed: dead-letter the message and keep consuming
                logger.info(f"Processing message with delivery_tag {frame.delivery_tag} failed, declining.")
                logger.exception(err)
                self._channel.basic_nack(frame.delivery_tag, requeue=False)

            except Exception:
                # anything else (e.g. a body json.loads cannot decode, a publish failure):
                # dead-letter the message and re-raise so the error reaches the consuming loop
                n_attempts = _get_n_previous_attempts(properties) + 1
                logger.warning(f"Failed to process message, {n_attempts}", exc_info=True)
                self._channel.basic_nack(frame.delivery_tag, requeue=False)
                raise

        return callback
|
|
|
|
|
|
class QueueManagerV2:
    """RabbitMQ client configured from the ``rabbitmq`` section of Dynaconf settings.

    ``start_consuming`` reads raw message bodies from ``input_queue``, runs a user
    callback on each in a worker thread and publishes the callback's return value to
    ``output_queue``. ``publish_message`` and ``get_message`` cover the opposite
    direction: publish to the input queue, fetch from the output queue.
    """

    def __init__(self, settings: Union[Dynaconf, None] = None):
        """
        Args:
            settings: Dynaconf settings with a ``rabbitmq`` section. When omitted,
                ``load_settings()`` is called here, per instance, rather than once at
                import time as a default-argument expression would be.
        """
        if settings is None:
            settings = load_settings()

        self.input_queue = settings.rabbitmq.input_queue
        self.output_queue = settings.rabbitmq.output_queue
        self.dead_letter_queue = settings.rabbitmq.dead_letter_queue

        self.connection_parameters = self.create_connection_parameters(settings)

        # created lazily by establish_connection()
        self.connection: Union[BlockingConnection, None] = None
        self.channel: Union[BlockingChannel, None] = None

        self.consumer_thread: Union[threading.Thread, None] = None
        self.worker_threads: list[threading.Thread] = []

        # close channel/connection on interpreter exit and on SIGTERM/SIGINT
        atexit.register(self.stop_consuming)
        signal.signal(signal.SIGTERM, self._handle_stop_signal)
        signal.signal(signal.SIGINT, self._handle_stop_signal)

    @staticmethod
    def create_connection_parameters(settings: Dynaconf):
        """Build pika connection parameters from ``settings.rabbitmq`` (host, port, credentials, heartbeat)."""
        credentials = pika.PlainCredentials(username=settings.rabbitmq.username, password=settings.rabbitmq.password)
        pika_connection_params = {
            "host": settings.rabbitmq.host,
            "port": settings.rabbitmq.port,
            "credentials": credentials,
            "heartbeat": settings.rabbitmq.heartbeat,
        }

        return pika.ConnectionParameters(**pika_connection_params)

    @retry(tries=5, delay=5, jitter=(1, 3))
    def establish_connection(self):
        """Connect to RabbitMQ and declare both queues, unless a connection is already open.

        Attempted up to 5 times by the ``retry`` decorator (5 s delay plus 1-3 s jitter).
        """
        # TODO: set sensible retry parameters
        if self.connection and self.connection.is_open:
            logger.debug("Connection to RabbitMQ already established.")
            return

        logger.info("Establishing connection to RabbitMQ...")
        self.connection = pika.BlockingConnection(parameters=self.connection_parameters)
        self.channel = self.connection.channel()
        # at most one unacknowledged message is delivered at a time
        self.channel.basic_qos(prefetch_count=1)

        # messages rejected without requeue are routed through the default exchange
        # to the dead-letter queue
        args = {
            "x-dead-letter-exchange": "",
            "x-dead-letter-routing-key": self.dead_letter_queue,
        }

        self.channel.queue_declare(self.input_queue, arguments=args, auto_delete=False, durable=True)
        self.channel.queue_declare(self.output_queue, arguments=args, auto_delete=False, durable=True)
        logger.info("Connection to RabbitMQ established.")

    def publish_message(self, message: dict, properties: Union[pika.BasicProperties, None] = None):
        """JSON-encode ``message`` (UTF-8) and publish it to the input queue."""
        self.establish_connection()
        message_encoded = json.dumps(message).encode("utf-8")
        self.channel.basic_publish(
            "",
            self.input_queue,
            properties=properties,
            body=message_encoded,
        )
        logger.info(f"Published message to queue {self.input_queue}.")

    def get_message(self):
        """Fetch a single message from the output queue; returns ``channel.basic_get``'s result as-is."""
        self.establish_connection()
        return self.channel.basic_get(self.output_queue)

    def create_on_message_callback(self, callback: Callable):
        """Wrap ``callback`` into a pika ``on_message_callback``.

        ``callback`` receives the raw message body. Its return value is passed to
        ``basic_publish`` unchanged, so it presumably must already be an encoded
        body (bytes/str); confirm with callers.
        """

        def process_message_body_and_await_result(unpacked_message_body):
            with concurrent.futures.ThreadPoolExecutor(max_workers=1) as thread_pool_executor:
                logger.debug("Processing payload in separate thread.")
                future = thread_pool_executor.submit(callback, unpacked_message_body)

                # Poll with done() rather than running(): right after submit() the future is
                # usually still pending, so a running() loop could exit immediately and then
                # block in result() without servicing heartbeats.
                while not future.done():
                    logger.debug("Waiting for payload processing to finish...")
                    self.connection.process_data_events()
                    self.connection.sleep(5)

                return future.result()

        def cb(ch, method, properties, body):
            logger.info(f"Received message from queue with delivery_tag {method.delivery_tag}.")
            result = process_message_body_and_await_result(body)
            logger.info(f"Processed message with delivery_tag {method.delivery_tag}, publishing result to result-queue.")
            ch.basic_publish(
                "",
                self.output_queue,
                result,
            )

            # acknowledge only after the result has been published
            ch.basic_ack(delivery_tag=method.delivery_tag)
            logger.info(f"Message with delivery tag {method.delivery_tag} acknowledged.")

        return cb

    def start_consuming(self, message_processor: Callable):
        """Register ``message_processor`` on the input queue and block in pika's consuming loop."""
        on_message_callback = self.create_on_message_callback(message_processor)
        self.establish_connection()
        self.channel.basic_consume(self.input_queue, on_message_callback)
        try:
            self.channel.start_consuming()
        except KeyboardInterrupt:
            self.stop_consuming()

    def stop_consuming(self):
        """Stop consuming, close channel and connection (if open) and join worker threads."""
        if self.channel and self.channel.is_open:
            logger.info("Stopping consuming...")
            self.channel.stop_consuming()
            logger.info("Closing channel...")
            self.channel.close()

        if self.connection and self.connection.is_open:
            logger.info("Closing connection to RabbitMQ...")
            self.connection.close()

        logger.info("Waiting for worker threads to finish...")

        for thread in self.worker_threads:
            logger.info(f"Stopping worker thread {thread.name}...")
            thread.join()
            logger.info(f"Worker thread {thread.name} stopped.")

    def _handle_stop_signal(self, signum, *args, **kwargs):
        """Signal handler: shut down cleanly, then exit the process with status 0."""
        logger.info(f"Received signal {signum}, stopping consuming...")
        self.stop_consuming()
        sys.exit(0)