pyinfra/pyinfra/queue/queue_manager/pika_queue_manager.py

161 lines
5.1 KiB
Python

import json
import logging
import pika
from pyinfra.config import CONFIG
from pyinfra.exceptions import ProcessingFailure
from pyinfra.queue.queue_manager.queue_manager import QueueHandle, QueueManager
# Raise the threshold of the "pika" logger so only WARNING and above pass.
logging.getLogger("pika").setLevel(logging.WARNING)

# Module-wide logger; getLogger() without a name returns the root logger.
logger = logging.getLogger()
def monkey_patch_queue_handle(channel, queue) -> QueueHandle:
    """Build a ``QueueHandle`` whose ``empty`` and ``to_list`` act on a live queue.

    Args:
        channel: Channel object providing ``basic_get``, ``basic_ack`` and
            ``basic_reject`` (the manager's pika channel in this module).
        queue: Name of the queue to inspect and drain.

    Returns:
        A new ``QueueHandle`` with two callables attached:
        ``empty()`` reports whether ``basic_get`` finds no message, and
        ``to_list()`` drains the queue and returns the JSON-decoded bodies.
    """
    # Value returned by ``basic_get`` when no message is available.
    empty_message = (None, None, None)

    def is_empty_message(message):
        return message == empty_message

    def queue_is_empty():
        message = channel.basic_get(queue)
        if is_empty_message(message):
            return True
        # The probe fetched a real message without acknowledging it. Hand it
        # back to the broker; otherwise it would stay unacknowledged on this
        # channel and be missing from a subsequent ``to_list()``.
        method_frame, _, _ = message
        channel.basic_reject(delivery_tag=method_frame.delivery_tag, requeue=True)
        return False

    def produce_items():
        # Fetch, acknowledge and decode messages until the queue runs dry.
        while True:
            message = channel.basic_get(queue)
            if is_empty_message(message):
                break
            method_frame, _, body = message
            channel.basic_ack(method_frame.delivery_tag)
            yield json.loads(body)

    queue_handle = QueueHandle()
    queue_handle.empty = queue_is_empty
    queue_handle.to_list = lambda: list(produce_items())
    return queue_handle
def get_connection_params():
    """Assemble pika connection parameters from the ``CONFIG.rabbitmq`` settings.

    Returns:
        A ``pika.ConnectionParameters`` built from the configured host, port
        and heartbeat, using plain user/password credentials.
    """
    rabbitmq = CONFIG.rabbitmq
    plain_credentials = pika.PlainCredentials(
        username=rabbitmq.user,
        password=rabbitmq.password,
    )
    return pika.ConnectionParameters(
        host=rabbitmq.host,
        port=rabbitmq.port,
        credentials=plain_credentials,
        heartbeat=rabbitmq.heartbeat,
    )
def get_n_previous_attempts(props):
    """Return how many times a message has already been attempted.

    The count is read from the ``x-retry-count`` entry of ``props.headers``.
    Absent headers (``None``) or a missing entry both count as zero.
    """
    headers = props.headers
    if headers is None:
        return 0
    return headers.get("x-retry-count", 0)
def attempts_remain(n_attempts, max_attempts):
    """Tell whether another attempt is allowed.

    Returns ``True`` while ``n_attempts`` is strictly below ``max_attempts``.
    """
    budget_left = n_attempts < max_attempts
    return budget_left
class PikaQueueManager(QueueManager):
    """Queue manager backed by a blocking pika connection.

    Requests are taken from the input queue and handed to a callback; the
    JSON-encoded result is published to the output queue. When the callback
    raises ``ProcessingFailure`` the message is republished with an
    ``x-retry-count`` header until the attempt budget is exhausted, after
    which it is rejected without requeueing.
    """

    def __init__(self, input_queue, output_queue, dead_letter_queue=None, connection_params=None):
        """Open the connection and declare the input and output queues.

        Args:
            input_queue: Name of the queue requests are read from.
            output_queue: Name of the queue responses are published to.
            dead_letter_queue: Dead-letter routing key for both queues; falls
                back to ``CONFIG.rabbitmq.queues.dead_letter`` when falsy.
            connection_params: pika connection parameters; built by
                ``get_connection_params()`` when falsy.
        """
        # NOTE(review): the base initializer presumably stores the names as
        # ``_input_queue`` / ``_output_queue`` -- confirm in QueueManager.
        super().__init__(input_queue, output_queue)

        if not connection_params:
            connection_params = get_connection_params()

        self.connection = pika.BlockingConnection(parameters=connection_params)
        self.channel = self.connection.channel()
        self.channel.basic_qos(prefetch_count=1)

        if not dead_letter_queue:
            dead_letter_queue = CONFIG.rabbitmq.queues.dead_letter

        # An empty exchange name selects the default exchange, so the routing
        # key is the name of the dead-letter queue itself.
        args = {"x-dead-letter-exchange": "", "x-dead-letter-routing-key": dead_letter_queue}

        self.channel.queue_declare(input_queue, arguments=args, auto_delete=False, durable=True)
        self.channel.queue_declare(output_queue, arguments=args, auto_delete=False, durable=True)

    def republish(self, body, n_current_attempts, frame):
        """Put ``body`` back on the input queue with an updated retry count.

        The new copy carries ``n_current_attempts`` in its ``x-retry-count``
        header; the original delivery identified by ``frame`` is acknowledged
        afterwards.
        """
        self.channel.basic_publish(
            exchange="",
            routing_key=self._input_queue,
            body=body,
            properties=pika.BasicProperties(headers={"x-retry-count": n_current_attempts}),
        )
        self.channel.basic_ack(delivery_tag=frame.delivery_tag)

    def publish_request(self, request):
        """Publish ``request`` to the input queue as UTF-8 encoded JSON."""
        logger.debug("Publishing %s", request)
        self.channel.basic_publish("", self._input_queue, json.dumps(request).encode())

    def reject(self, body, frame):
        """Reject the delivery identified by ``frame`` without requeueing it.

        Logs via ``logger.exception``, so it is meant to be called from inside
        an exception handler (as ``publish_response`` does).
        """
        logger.exception("Adding to dead letter queue: %s", body)
        self.channel.basic_reject(delivery_tag=frame.delivery_tag, requeue=False)

    def publish_response(self, message, callback, max_attempts=3):
        """Process one message and publish the callback's result.

        Args:
            message: ``(frame, properties, body)`` triple as produced by
                ``pull_request()`` or ``consume()``.
            callback: Callable receiving the JSON-decoded body and returning a
                JSON-serializable response.
            max_attempts: Number of processing attempts allowed before the
                message is rejected.

        A ``ProcessingFailure`` raised by the callback triggers a republish
        while attempts remain, otherwise a reject. Other exceptions propagate.
        """
        frame, properties, body = message

        # ``pull_request()`` on an empty queue and ``consume()`` with an
        # inactivity timeout yield ``(None, None, None)``; there is nothing to
        # process or acknowledge in that case.
        if frame is None:
            logger.debug("Received empty message; nothing to process.")
            return

        logger.debug("Processing %s.", message)

        n_attempts = get_n_previous_attempts(properties) + 1

        try:
            response = json.dumps(callback(json.loads(body)))
            self.channel.basic_publish("", self._output_queue, response.encode())
            self.channel.basic_ack(frame.delivery_tag)
        except ProcessingFailure:
            logger.error("Message failed to process %s/%s times: %s", n_attempts, max_attempts, body)

            if attempts_remain(n_attempts, max_attempts):
                self.republish(body, n_attempts, frame)
            else:
                self.reject(body, frame)

    def pull_request(self):
        """Fetch a single message from the input queue via ``basic_get``."""
        return self.channel.basic_get(self._input_queue)

    def consume(self, inactivity_timeout=None):
        """Return the channel's message generator for the input queue.

        Args:
            inactivity_timeout: Forwarded unchanged to ``channel.consume``.
        """
        logger.debug("Consuming")
        return self.channel.consume(self._input_queue, inactivity_timeout=inactivity_timeout)

    def consume_and_publish(self, visitor):
        """Run ``publish_response`` with ``visitor`` for every consumed message."""
        logger.info("Consuming input queue.")

        for message in self.consume():
            self.publish_response(message, visitor)

    def basic_consume_and_publish(self, visitor):
        """Register a consumer callback on the input queue and start consuming.

        Each delivery is repackaged as a ``(frame, properties, body)`` triple
        and passed to ``publish_response`` together with ``visitor``.
        """
        logger.info("Basic consuming input queue.")

        def callback(channel, frame, properties, body):
            message = (frame, properties, body)
            return self.publish_response(message, visitor)

        self.channel.basic_consume(self._input_queue, callback)
        self.channel.start_consuming()

    def clear(self):
        """Purge the input and output queues on a best-effort basis.

        A ``ChannelWrongStateError`` from the channel is logged and ignored.
        """
        try:
            self.channel.queue_purge(self._input_queue)
            self.channel.queue_purge(self._output_queue)
        except pika.exceptions.ChannelWrongStateError:
            logger.debug("Channel in wrong state; queues were not purged.")

    @property
    def input_queue(self) -> QueueHandle:
        """Handle exposing ``empty()`` / ``to_list()`` for the input queue."""
        return monkey_patch_queue_handle(self.channel, self._input_queue)

    @property
    def output_queue(self) -> QueueHandle:
        """Handle exposing ``empty()`` / ``to_list()`` for the output queue."""
        return monkey_patch_queue_handle(self.channel, self._output_queue)