pyinfra/pyinfra/queue/queue_manager.py
2024-01-15 16:46:33 +01:00

343 lines
14 KiB
Python

import atexit
import concurrent.futures
import json
import logging
import sys
import threading
import time
from functools import partial
from typing import Union, Callable
import pika
import pika.exceptions
import signal
from dynaconf import Dynaconf
from kn_utils.logging import logger
from pathlib import Path
from pika.adapters.blocking_connection import BlockingChannel, BlockingConnection
from retry import retry
from pyinfra.config import Config, load_settings
from pyinfra.exception import ProcessingFailure
from pyinfra.payload_processing.processor import PayloadProcessor
from pyinfra.utils.dict import safe_project
# Module-level config instance; note: instantiated once at import time.
CONFIG = Config()

# Pika logs verbosely at INFO/DEBUG; raise its level so application logs stay readable.
pika_logger = logging.getLogger("pika")
pika_logger.setLevel(logging.WARNING)  # disables non-informative pika log clutter
def get_connection_params(config: Config) -> pika.ConnectionParameters:
    """Build pika connection parameters from a pyinfra ``Config``.

    Args:
        config (pyinfra.Config): standard pyinfra config class

    Returns:
        pika.ConnectionParameters: parameters carrying host, port,
        plain credentials and heartbeat interval taken from *config*.
    """
    return pika.ConnectionParameters(
        host=config.rabbitmq_host,
        port=config.rabbitmq_port,
        credentials=pika.PlainCredentials(
            username=config.rabbitmq_username,
            password=config.rabbitmq_password,
        ),
        heartbeat=config.rabbitmq_heartbeat,
    )
def _get_n_previous_attempts(props):
return 0 if props.headers is None else props.headers.get("x-retry-count", 0)
def token_file_name() -> Path:
    """Return the path of the file used to persist the consumer token.

    Returns:
        pathlib.Path: ``/tmp/consumer_token.txt``
    """
    return Path("/tmp") / "consumer_token.txt"
class QueueManager:
    """Handle RabbitMQ message reception & delivery"""

    def __init__(self, config: Config):
        """Set up queue names, connection parameters and shutdown hooks.

        Args:
            config (pyinfra.Config): standard pyinfra config class
        """
        self._input_queue = config.request_queue
        self._output_queue = config.response_queue
        self._dead_letter_queue = config.dead_letter_queue
        # controls how often we send out a life signal
        self._heartbeat = config.rabbitmq_heartbeat
        # controls for how long we only process data events (e.g. heartbeats),
        # while the queue is blocked and we process the given callback function
        self._connection_sleep = config.rabbitmq_connection_sleep
        # NOTE(review): only the exact string "True" enables token writing —
        # a boolean True coming from config would disable it; confirm intended.
        self._write_token = config.write_consumer_token == "True"
        self._set_consumer_token(None)
        self._connection_params = get_connection_params(config)
        self._connection = pika.BlockingConnection(parameters=self._connection_params)
        self._channel: BlockingChannel
        # necessary so pods can be terminated/restarted in K8s/docker
        atexit.register(self.stop_consuming)
        signal.signal(signal.SIGTERM, self._handle_stop_signal)
        signal.signal(signal.SIGINT, self._handle_stop_signal)

    def _set_consumer_token(self, token_value):
        """Store the consumer tag and, when enabled, persist it to a file.

        Args:
            token_value: pika consumer tag string, or None to clear it
                (an empty file marks "no active consumer").
        """
        self._consumer_token = token_value
        if self._write_token:
            token_file_path = token_file_name()
            with token_file_path.open(mode="w", encoding="utf8") as token_file:
                text = token_value if token_value is not None else ""
                token_file.write(text)

    def _open_channel(self):
        """Open a channel with prefetch=1 and declare both durable queues.

        Rejected messages are routed via the default exchange ("") to the
        configured dead-letter queue.
        """
        self._channel = self._connection.channel()
        # At most one unacknowledged message per consumer at a time.
        self._channel.basic_qos(prefetch_count=1)
        args = {
            "x-dead-letter-exchange": "",
            "x-dead-letter-routing-key": self._dead_letter_queue,
        }
        self._channel.queue_declare(self._input_queue, arguments=args, auto_delete=False, durable=True)
        self._channel.queue_declare(self._output_queue, arguments=args, auto_delete=False, durable=True)

    def start_consuming(self, process_payload: PayloadProcessor):
        """consumption handling

        - standard callback handling is enforced through wrapping process_message_callback in _create_queue_callback
          (implements threading to support heartbeats)
        - initially sets consumer token to None
        - tries to
            - open channels
            - set consumer token to basic_consume, passing in the standard callback and input queue name
            - calls pika start_consuming method on the channels
        - catches all Exceptions & stops consuming + closes channels

        Args:
            process_payload (Callable): function passed to the queue manager, configured by implementing service
        """
        callback = self._create_queue_callback(process_payload)
        self._set_consumer_token(None)
        try:
            self._open_channel()
            self._set_consumer_token(self._channel.basic_consume(self._input_queue, callback))
            logger.info(f"Registered with consumer-tag: {self._consumer_token}")
            # Blocks until stop_consuming is called or an error occurs.
            self._channel.start_consuming()
        except Exception:
            logger.error(
                "An unexpected exception occurred while consuming messages. Consuming will stop.", exc_info=True
            )
            raise
        finally:
            # Always cancel the subscription and close the connection,
            # even on the error path.
            self.stop_consuming()
            self._connection.close()

    def stop_consuming(self):
        """Cancel the active subscription (if any) and clear the consumer token.

        Safe to call repeatedly: a cleared token makes this a no-op, which
        matters because it is registered both with atexit and as a signal
        handler.
        """
        if self._consumer_token and self._connection:
            logger.info(f"Cancelling subscription for consumer-tag {self._consumer_token}")
            self._channel.stop_consuming(self._consumer_token)
            self._set_consumer_token(None)

    def _handle_stop_signal(self, signal_number, _stack_frame, *args, **kwargs):
        """Signal handler (SIGTERM/SIGINT): log and stop consuming."""
        logger.info(f"Received signal {signal_number}")
        self.stop_consuming()

    def _create_queue_callback(self, process_payload: PayloadProcessor):
        """Wrap *process_payload* into a standard pika on-message callback.

        The payload is processed in a worker thread while the main thread
        keeps servicing connection data events (heartbeats), so long-running
        processing does not drop the AMQP connection.
        """

        def process_message_body_and_await_result(unpacked_message_body):
            # Run the processor in a separate thread; poll it while letting
            # the blocking connection sleep (which services heartbeats).
            with concurrent.futures.ThreadPoolExecutor(max_workers=1) as thread_pool_executor:
                logger.debug("Processing payload in separate thread.")
                future = thread_pool_executor.submit(process_payload, unpacked_message_body)
                while future.running():
                    logger.debug("Waiting for payload processing to finish...")
                    self._connection.sleep(float(self._connection_sleep))
                try:
                    return future.result()
                except Exception as err:
                    # Normalize any processing error into ProcessingFailure,
                    # preserving the original cause.
                    raise ProcessingFailure(f"QueueMessagePayload processing failed: {repr(err)}") from err

        def acknowledge_message_and_publish_response(frame, headers, response_body):
            # Publish the result first, then ack — the incoming message is
            # only acknowledged once the response is on the output queue.
            response_properties = pika.BasicProperties(headers=headers) if headers else None
            self._channel.basic_publish("", self._output_queue, json.dumps(response_body).encode(), response_properties)
            logger.debug(f"Result published, acknowledging incoming message with delivery_tag {frame.delivery_tag}.")
            self._channel.basic_ack(frame.delivery_tag)

        def callback(_channel, frame, properties, body):
            logger.info(f"Received message from queue with delivery_tag {frame.delivery_tag}.")
            logger.debug(f"Message headers: {properties.headers}")
            # Only try to process each message once. Re-queueing will be handled by the dead-letter-exchange. This
            # prevents endless retries on messages that are impossible to process.
            if frame.redelivered:
                logger.info(
                    f"Aborting message processing for delivery_tag {frame.delivery_tag} due to it being redelivered.",
                )
                self._channel.basic_nack(frame.delivery_tag, requeue=False)
                return
            try:
                logger.debug(f"Processing {frame}, {properties}, {body}")
                # Forward only the whitelisted header(s) into the message body
                # and back onto the response.
                filtered_message_headers = safe_project(properties.headers, ["X-TENANT-ID"])
                message_body = {**json.loads(body), **filtered_message_headers}
                processing_result = process_message_body_and_await_result(message_body)
                logger.info(
                    f"Processed message with delivery_tag {frame.delivery_tag}, publishing result to result-queue."
                )
                acknowledge_message_and_publish_response(frame, filtered_message_headers, processing_result)
            except ProcessingFailure as err:
                # Known processing failure: decline without requeue (goes to DLQ).
                logger.info(f"Processing message with delivery_tag {frame.delivery_tag} failed, declining.")
                logger.exception(err)
                self._channel.basic_nack(frame.delivery_tag, requeue=False)
            except Exception:
                # Unexpected failure: decline the message, then re-raise so the
                # consumer loop stops (see start_consuming).
                n_attempts = _get_n_previous_attempts(properties) + 1
                logger.warning(f"Failed to process message, {n_attempts}", exc_info=True)
                self._channel.basic_nack(frame.delivery_tag, requeue=False)
                raise

        return callback
class QueueManagerV2:
    """RabbitMQ queue manager configured via Dynaconf settings.

    Declares durable input/output queues (with dead-letter routing),
    supports publishing, polling (``basic_get``) and blocking consumption.
    Registers atexit and signal handlers so channel/connection are closed
    cleanly on shutdown.
    """

    def __init__(self, settings: Union[Dynaconf, None] = None):
        """Initialize queue names, connection parameters and shutdown hooks.

        Args:
            settings: Dynaconf settings with a ``rabbitmq`` section. When
                None, settings are loaded via ``load_settings()`` at call
                time. (Fixes the previous eager default
                ``settings: Dynaconf = load_settings()``, which executed
                ``load_settings()`` once at import time and shared the
                result across all instances.)
        """
        if settings is None:
            settings = load_settings()
        self.input_queue = settings.rabbitmq.input_queue
        self.output_queue = settings.rabbitmq.output_queue
        self.dead_letter_queue = settings.rabbitmq.dead_letter_queue
        self.connection_parameters = self.create_connection_parameters(settings)
        # Connection/channel are opened lazily by establish_connection().
        self.connection: Union[BlockingConnection, None] = None
        self.channel: Union[BlockingChannel, None] = None
        self.consumer_thread: Union[threading.Thread, None] = None
        self.worker_threads: list[threading.Thread] = []
        # Ensure clean shutdown on process exit and on SIGTERM/SIGINT
        # (required so pods can be terminated/restarted in K8s/docker).
        atexit.register(self.stop_consuming)
        signal.signal(signal.SIGTERM, self._handle_stop_signal)
        signal.signal(signal.SIGINT, self._handle_stop_signal)

    @staticmethod
    def create_connection_parameters(settings: Dynaconf) -> pika.ConnectionParameters:
        """Build pika connection parameters from the ``rabbitmq`` settings section."""
        credentials = pika.PlainCredentials(
            username=settings.rabbitmq.username, password=settings.rabbitmq.password
        )
        return pika.ConnectionParameters(
            host=settings.rabbitmq.host,
            port=settings.rabbitmq.port,
            credentials=credentials,
            heartbeat=settings.rabbitmq.heartbeat,
        )

    @retry(tries=5, delay=5, jitter=(1, 3))
    def establish_connection(self):
        """Open a blocking connection/channel and declare both queues.

        Idempotent: returns immediately when a connection is already open.
        Retried up to 5 times with delay/jitter (see decorator).
        """
        # TODO: set sensible retry parameters
        if self.connection and self.connection.is_open:
            logger.debug("Connection to RabbitMQ already established.")
            return
        logger.info("Establishing connection to RabbitMQ...")
        self.connection = pika.BlockingConnection(parameters=self.connection_parameters)
        self.channel = self.connection.channel()
        # At most one unacknowledged message per consumer at a time.
        self.channel.basic_qos(prefetch_count=1)
        args = {
            "x-dead-letter-exchange": "",
            "x-dead-letter-routing-key": self.dead_letter_queue,
        }
        self.channel.queue_declare(self.input_queue, arguments=args, auto_delete=False, durable=True)
        self.channel.queue_declare(self.output_queue, arguments=args, auto_delete=False, durable=True)
        logger.info("Connection to RabbitMQ established.")

    def publish_message(self, message: dict, properties: Union[pika.BasicProperties, None] = None):
        """JSON-encode *message* and publish it to the input queue.

        Args:
            message: JSON-serializable payload.
            properties: optional pika properties (e.g. headers) for the message.
        """
        self.establish_connection()
        message_encoded = json.dumps(message).encode("utf-8")
        self.channel.basic_publish(
            "",
            self.input_queue,
            properties=properties,
            body=message_encoded,
        )
        logger.info(f"Published message to queue {self.input_queue}.")

    def get_message(self):
        """Fetch a single message from the output queue (non-blocking ``basic_get``)."""
        self.establish_connection()
        return self.channel.basic_get(self.output_queue)

    def create_on_message_callback(self, callback: Callable):
        """Wrap *callback* into a pika on-message callback.

        The payload is processed in a worker thread while the main thread
        keeps servicing connection data events (heartbeats), so long-running
        processing does not drop the AMQP connection.
        """

        def process_message_body_and_await_result(unpacked_message_body):
            with concurrent.futures.ThreadPoolExecutor(max_workers=1) as thread_pool_executor:
                logger.debug("Processing payload in separate thread.")
                future = thread_pool_executor.submit(callback, unpacked_message_body)
                while future.running():
                    logger.debug("Waiting for payload processing to finish...")
                    # Service heartbeats/data events while waiting.
                    self.connection.process_data_events()
                    self.connection.sleep(5)
                return future.result()

        def cb(ch, method, properties, body):
            logger.info(f"Received message from queue with delivery_tag {method.delivery_tag}.")
            result = process_message_body_and_await_result(body)
            logger.info(f"Processed message with delivery_tag {method.delivery_tag}, publishing result to result-queue.")
            # NOTE(review): *result* is published as-is (not JSON-encoded
            # like publish_message) — the callback is expected to return
            # bytes/str suitable for basic_publish; confirm against callers.
            ch.basic_publish(
                "",
                self.output_queue,
                result,
            )
            # Ack only after the result has been published.
            ch.basic_ack(delivery_tag=method.delivery_tag)
            logger.info(f"Message with delivery tag {method.delivery_tag} acknowledged.")

        return cb

    def start_consuming(self, message_processor: Callable):
        """Consume from the input queue until interrupted.

        Args:
            message_processor: called with each raw message body; its return
                value is published to the output queue.
        """
        on_message_callback = self.create_on_message_callback(message_processor)
        self.establish_connection()
        self.channel.basic_consume(self.input_queue, on_message_callback)
        try:
            self.channel.start_consuming()
        except KeyboardInterrupt:
            self.stop_consuming()

    def stop_consuming(self):
        """Stop consuming, close channel/connection and join worker threads.

        Safe to call repeatedly (atexit + signal handler may both fire):
        channel/connection are only touched while open.
        """
        if self.channel and self.channel.is_open:
            logger.info("Stopping consuming...")
            self.channel.stop_consuming()
            logger.info("Closing channel...")
            self.channel.close()
        if self.connection and self.connection.is_open:
            logger.info("Closing connection to RabbitMQ...")
            self.connection.close()
        logger.info("Waiting for worker threads to finish...")
        for thread in self.worker_threads:
            logger.info(f"Stopping worker thread {thread.name}...")
            thread.join()
            logger.info(f"Worker thread {thread.name} stopped.")

    def _handle_stop_signal(self, signum, *args, **kwargs):
        """Signal handler (SIGTERM/SIGINT): stop consuming and exit cleanly."""
        logger.info(f"Received signal {signum}, stopping consuming...")
        self.stop_consuming()
        sys.exit(0)