pyinfra/pyinfra/queue/queue_manager.py

128 lines
4.7 KiB
Python

import atexit
import json
import logging
import signal
from typing import Callable
import pika
import pika.exceptions
from pyinfra.config import Config
# pika logs connection/channel chatter at INFO/DEBUG; raise its logger's
# threshold so only WARNING and above are emitted from it.
pika_logger = logging.getLogger(name="pika")
pika_logger.setLevel(level=logging.WARNING)
def get_connection_params(config: Config) -> pika.ConnectionParameters:
    """Build pika connection parameters for the RabbitMQ broker described by *config*.

    Reads ``rabbitmq_host``, ``rabbitmq_port``, ``rabbitmq_username``,
    ``rabbitmq_password`` and ``rabbitmq_heartbeat`` from *config*.

    :param config: pyinfra configuration object.
    :return: a ``pika.ConnectionParameters`` using plain username/password credentials.
    """
    return pika.ConnectionParameters(
        host=config.rabbitmq_host,
        port=config.rabbitmq_port,
        credentials=pika.PlainCredentials(
            username=config.rabbitmq_username,
            password=config.rabbitmq_password,
        ),
        # Explicitly coerced to int before it is handed to pika.
        heartbeat=int(config.rabbitmq_heartbeat),
    )
def _get_n_previous_attempts(props):
    """Return the prior-attempt count recorded in a message's ``x-retry-count`` header.

    :param props: message properties object exposing a ``headers`` attribute
        (``None`` or a dict-like object supporting ``.get``).
    :return: the header's value, or 0 when there are no headers or the header is absent.
    """
    headers = props.headers
    if headers is None:
        return 0
    return headers.get("x-retry-count", 0)
class QueueManager(object):
    """Consume JSON requests from RabbitMQ, process them, and publish JSON results.

    Messages are read from ``config.request_queue`` one at a time (prefetch 1),
    decoded with ``json.loads`` and passed to a user callback. The callback's
    return value is JSON-encoded and published to ``config.response_queue``.
    Messages whose processing fails are nacked without requeue. Both queues are
    declared with ``config.dead_letter_queue`` as their dead-letter routing key.

    Constructing an instance registers ``stop_consuming`` with ``atexit`` and
    installs SIGTERM/SIGINT handlers. ``signal.signal`` only works in the main
    thread, so instances must be created there.
    """

    def __init__(self, config: Config):
        """Read queue names and connection settings from *config*; no connection is opened yet."""
        self.logger = logging.getLogger("queue_manager")
        self.logger.setLevel(config.logging_level_root)

        self._consumer_token = None
        # Populated by _open_channel(). They are initialised here so the cleanup
        # paths (atexit, signal handler, finally) are safe before any connection exists.
        self._connection = None
        self._channel = None

        self._connection_params = get_connection_params(config)
        self._input_queue = config.request_queue
        self._output_queue = config.response_queue
        self._dead_letter_queue = config.dead_letter_queue

        atexit.register(self.stop_consuming)
        signal.signal(signal.SIGTERM, self._handle_stop_signal)
        signal.signal(signal.SIGINT, self._handle_stop_signal)

    def _open_channel(self):
        """Connect to the broker, open a channel and declare the input and output queues."""
        self._connection = pika.BlockingConnection(parameters=self._connection_params)
        self._channel = self._connection.channel()
        # At most one unacknowledged message is delivered to this consumer at a time.
        self._channel.basic_qos(prefetch_count=1)
        # Dead-lettered messages go via the default exchange ("") to the dead-letter queue.
        args = {"x-dead-letter-exchange": "", "x-dead-letter-routing-key": self._dead_letter_queue}
        self._channel.queue_declare(self._input_queue, arguments=args, auto_delete=False, durable=True)
        self._channel.queue_declare(self._output_queue, arguments=args, auto_delete=False, durable=True)

    def _close_channel(self):
        """Close the channel and the connection if they are open; otherwise do nothing."""
        # Guarded individually: a channel that is already closed (or was never
        # opened) must not raise here and prevent the connection from being closed.
        if self._channel is not None and self._channel.is_open:
            self._channel.close()
        if self._connection is not None and self._connection.is_open:
            self._connection.close()

    def start_consuming(self, process_message_callback: Callable):
        """Open a channel and block while consuming from the input queue.

        Returns once the consumer is cancelled (e.g. via ``stop_consuming`` from a
        signal handler). The channel and connection are always cleaned up on exit.

        :param process_message_callback: called with each decoded JSON message body.
            Its return value must be JSON-serialisable; it is published to the output queue.
        :raises Exception: re-raises whatever stopped consumption, including
            exceptions raised by ``process_message_callback``.
        """
        callback = self._create_queue_callback(process_message_callback)
        self._consumer_token = None
        try:
            # Kept inside the try so a failure after the connection is opened
            # (e.g. while declaring queues) still closes it in the finally clause.
            self._open_channel()
            self.logger.info("Consuming from queue")
            self._consumer_token = self._channel.basic_consume(self._input_queue, callback)
            self.logger.info(f"Registered with consumer-tag: {self._consumer_token}")
            self._channel.start_consuming()
        except Exception:
            self.logger.warning(
                "An unexpected exception occurred while consuming messages. Consuming will stop."
            )
            raise
        finally:
            self.stop_consuming()

    def stop_consuming(self):
        """Cancel the active consumer (if any) and close the channel and connection.

        Safe to call repeatedly and before any channel has been opened. It is
        registered with ``atexit``, called by the SIGTERM/SIGINT handler, and run
        in ``start_consuming``'s ``finally`` clause.
        """
        # Only cancel on a live channel: basic_cancel on a dead one would raise and,
        # inside start_consuming's finally, mask the exception that stopped consumption.
        if self._consumer_token and self._channel is not None and self._channel.is_open:
            self.logger.info(f"Cancelling subscription for consumer-tag: {self._consumer_token}")
            self._channel.basic_cancel(self._consumer_token)
        self._consumer_token = None
        # Always release the connection, even if no consumer was ever registered
        # (e.g. basic_consume or queue_declare failed).
        self._close_channel()

    def _handle_stop_signal(self, signal_number, _stack_frame, *args, **kwargs):
        """Signal handler for SIGTERM/SIGINT: log the signal and stop consuming."""
        self.logger.info(f"Received signal {signal_number}")
        self.stop_consuming()

    def _create_queue_callback(self, process_message_callback: Callable):
        """Wrap *process_message_callback* in a pika ``on_message_callback``.

        The wrapper decodes the body as JSON, calls the processor, publishes the
        JSON-encoded result to the output queue via the default exchange, then acks
        the delivery. On any exception it nacks the delivery without requeue and
        re-raises, which propagates out of ``start_consuming`` and stops consumption.
        """

        def callback(_channel, frame, properties, body):
            self.logger.info(f"Received message from queue with delivery_tag {frame.delivery_tag}")
            self.logger.debug(f"Processing {(frame, properties, body)}.")
            try:
                unpacked_message_body = json.loads(body)
                callback_result = process_message_callback(unpacked_message_body)
                self.logger.info("Processed message, publishing result to result-queue")
                self._channel.basic_publish("", self._output_queue, json.dumps(callback_result).encode())
                self.logger.info(
                    f"Result published, acknowledging incoming message with delivery_tag {frame.delivery_tag}"
                )
                # Ack only after the result has been published.
                self._channel.basic_ack(frame.delivery_tag)
                self.logger.info(f"Message with delivery_tag {frame.delivery_tag} processed")
            except Exception as ex:
                # The count comes from the x-retry-count header; this class only reads it.
                n_attempts = _get_n_previous_attempts(properties) + 1
                self.logger.warning(f"Failed to process message, {n_attempts} attempts, error: {str(ex)}")
                # requeue=False: the message is not put back on the input queue.
                self._channel.basic_nack(frame.delivery_tag, requeue=False)
                raise

        return callback

    def clear(self):
        """Purge all messages from the input and output queues on the current channel.

        A ``ChannelWrongStateError`` (e.g. the channel is already closed) is ignored.
        """
        try:
            self._channel.queue_purge(self._input_queue)
            self._channel.queue_purge(self._output_queue)
        except pika.exceptions.ChannelWrongStateError:
            pass