diff --git a/pyinfra/config.py b/pyinfra/config.py index 3ebcd4f..a4f2c4e 100644 --- a/pyinfra/config.py +++ b/pyinfra/config.py @@ -25,6 +25,10 @@ class Config(object): # Controls AMQP heartbeat timeout in seconds self.rabbitmq_heartbeat = read_from_environment("RABBITMQ_HEARTBEAT", "60") + # Controls AMQP connection sleep timer in seconds + # important for heartbeat to come through while main function runs on other thread + self.rabbitmq_connection_sleep = read_from_environment("RABBITMQ_CONNECTION_SLEEP", 5) + # Queue name for requests to the service self.request_queue = read_from_environment("REQUEST_QUEUE", "request_queue") diff --git a/pyinfra/queue/queue_manager.py b/pyinfra/queue/queue_manager.py index cb30d52..c091286 100644 --- a/pyinfra/queue/queue_manager.py +++ b/pyinfra/queue/queue_manager.py @@ -4,6 +4,7 @@ import logging import signal from typing import Callable from pathlib import Path +import concurrent.futures import pika import pika.exceptions @@ -46,6 +47,10 @@ class QueueManager(object): self._connection_params = get_connection_params(config) + # controls for how long we only process data events (e.g. heartbeats), + # while the queue is blocked and we process the given callback function + self._connection_sleep = config.rabbitmq_connection_sleep + self._input_queue = config.request_queue self._output_queue = config.response_queue self._dead_letter_queue = config.dead_letter_queue @@ -110,6 +115,18 @@ class QueueManager(object): self.stop_consuming() def _create_queue_callback(self, process_message_callback: Callable): + def process_message_body_and_await_result(unpacked_message_body): + with concurrent.futures.ThreadPoolExecutor(max_workers=1) as thread_pool_executor: + self.logger.debug("opening thread for callback") + future = thread_pool_executor.submit(process_message_callback, unpacked_message_body) + + while future.running(): + self.logger.debug("callback running in thread, processing data events in the meantime") + self._connection.sleep(float(self._connection_sleep)) + + self.logger.debug("fetching result from callback") + return future.result() + def callback(_channel, frame, properties, body): self.logger.info(f"Received message from queue with delivery_tag {frame.delivery_tag}") @@ -126,7 +143,7 @@ class QueueManager(object): try: unpacked_message_body = json.loads(body) - should_publish_result, callback_result = process_message_callback(unpacked_message_body) + should_publish_result, callback_result = process_message_body_and_await_result(unpacked_message_body) if should_publish_result: self.logger.info(f"Processed message with delivery_tag {frame.delivery_tag}, "