diff --git a/config/settings.toml b/config/settings.toml new file mode 100644 index 0000000..3760d2d --- /dev/null +++ b/config/settings.toml @@ -0,0 +1,40 @@ +[logging] +level = "DEBUG" + +[metrics.prometheus] +enabled = true +prefix = "redactmanager_research_service_parameter" # convention: '{product_name}_{service_name}_{parameter}' +host = "0.0.0.0" +port = 8080 + +[rabbitmq] +host = "localhost" +port = 5672 +username = "user" +password = "bitnami" +heartbeat = 5 +connection_sleep = 5 +write_consumer_token = false +input_queue = "request_queue" +output_queue = "response_queue" +dead_letter_queue = "dead_letter_queue" + +[storage] +type = "s3" + +[storage.s3] +bucket = "redaction" +endpoint = "http://127.0.0.1:9000" +key = "root" +secret = "password" +region = "eu-central-1" + +[storage.azure] +container = "redaction" +connection_string = "DefaultEndpointsProtocol=..." + +[multi_tenancy.server] +public_key = "redaction" +endpoint = "http://tenant-user-management:8081/internal-api/tenants" + + diff --git a/pyinfra/queue/queue_manager.py b/pyinfra/queue/queue_manager.py index b004c4d..c06ac7f 100644 --- a/pyinfra/queue/queue_manager.py +++ b/pyinfra/queue/queue_manager.py @@ -2,14 +2,23 @@ import atexit import concurrent.futures import json import logging +import sys +import threading +import time +from functools import partial +from typing import Union, Callable + import pika import pika.exceptions import signal + +from dynaconf import Dynaconf from kn_utils.logging import logger from pathlib import Path -from pika.adapters.blocking_connection import BlockingChannel +from pika.adapters.blocking_connection import BlockingChannel, BlockingConnection +from retry import retry -from pyinfra.config import Config +from pyinfra.config import Config, load_settings from pyinfra.exception import ProcessingFailure from pyinfra.payload_processing.processor import PayloadProcessor from pyinfra.utils.dict import safe_project @@ -203,3 +212,134 @@ class QueueManager: raise 
return callback + + +class QueueManagerV2: + def __init__(self, settings: Union[Dynaconf, None] = None): + settings = settings if settings is not None else load_settings() + self.input_queue = settings.rabbitmq.input_queue + self.output_queue = settings.rabbitmq.output_queue + self.dead_letter_queue = settings.rabbitmq.dead_letter_queue + + self.connection_parameters = self.create_connection_parameters(settings) + + self.connection: Union[BlockingConnection, None] = None + self.channel: Union[BlockingChannel, None] = None + + self.consumer_thread: Union[threading.Thread, None] = None + self.worker_threads: list[threading.Thread] = [] + + atexit.register(self.stop_consuming) + signal.signal(signal.SIGTERM, self._handle_stop_signal) + signal.signal(signal.SIGINT, self._handle_stop_signal) + + @staticmethod + def create_connection_parameters(settings: Dynaconf): + credentials = pika.PlainCredentials(username=settings.rabbitmq.username, password=settings.rabbitmq.password) + pika_connection_params = { + "host": settings.rabbitmq.host, + "port": settings.rabbitmq.port, + "credentials": credentials, + "heartbeat": settings.rabbitmq.heartbeat, + } + + return pika.ConnectionParameters(**pika_connection_params) + + @retry(tries=5, delay=5, jitter=(1, 3)) + def establish_connection(self): + # TODO: set sensible retry parameters + if self.connection and self.connection.is_open: + logger.debug("Connection to RabbitMQ already established.") + return + + logger.info("Establishing connection to RabbitMQ...") + self.connection = pika.BlockingConnection(parameters=self.connection_parameters) + self.channel = self.connection.channel() + self.channel.basic_qos(prefetch_count=1) + + args = { + "x-dead-letter-exchange": "", + "x-dead-letter-routing-key": self.dead_letter_queue, + } + + self.channel.queue_declare(self.input_queue, arguments=args, auto_delete=False, durable=True) + self.channel.queue_declare(self.output_queue, arguments=args, auto_delete=False, durable=True) + self.channel.queue_declare(self.dead_letter_queue, auto_delete=False, durable=True) + logger.info("Connection to RabbitMQ established.") + + def 
publish_message(self, message: dict, properties: pika.BasicProperties = None): + self.establish_connection() + message_encoded = json.dumps(message).encode("utf-8") + self.channel.basic_publish( + "", + self.input_queue, + properties=properties, + body=message_encoded, + ) + logger.info(f"Published message to queue {self.input_queue}.") + + def get_message(self): + self.establish_connection() + return self.channel.basic_get(self.output_queue) + + def create_on_message_callback(self, callback: Callable): + + def process_message_body_and_await_result(unpacked_message_body): + with concurrent.futures.ThreadPoolExecutor(max_workers=1) as thread_pool_executor: + logger.debug("Processing payload in separate thread.") + future = thread_pool_executor.submit(callback, unpacked_message_body) + + while future.running(): + logger.debug("Waiting for payload processing to finish...") + self.connection.process_data_events() + self.connection.sleep(5) + + return future.result() + + + def cb(ch, method, properties, body): + logger.info(f"Received message from queue with delivery_tag {method.delivery_tag}.") + result = process_message_body_and_await_result(body) + logger.info(f"Processed message with delivery_tag {method.delivery_tag}, publishing result to result-queue.") + ch.basic_publish( + "", + self.output_queue, + result, + ) + + ch.basic_ack(delivery_tag=method.delivery_tag) + logger.info(f"Message with delivery tag {method.delivery_tag} acknowledged.") + + return cb + + def start_consuming(self, message_processor: Callable): + on_message_callback = self.create_on_message_callback(message_processor) + self.establish_connection() + self.channel.basic_consume(self.input_queue, on_message_callback) + try: + self.channel.start_consuming() + except KeyboardInterrupt: + self.stop_consuming() + + def stop_consuming(self): + if self.channel and self.channel.is_open: + logger.info("Stopping consuming...") + self.channel.stop_consuming() + logger.info("Closing channel...") + 
self.channel.close() + + if self.connection and self.connection.is_open: + logger.info("Closing connection to RabbitMQ...") + self.connection.close() + + logger.info("Waiting for worker threads to finish...") + + for thread in self.worker_threads: + logger.info(f"Stopping worker thread {thread.name}...") + thread.join() + logger.info(f"Worker thread {thread.name} stopped.") + + def _handle_stop_signal(self, signum, *args, **kwargs): + logger.info(f"Received signal {signum}, stopping consuming...") + self.stop_consuming() + sys.exit(0) \ No newline at end of file diff --git a/tests/tests_with_docker_compose/queue_test.py b/tests/tests_with_docker_compose/queue_test.py index d293eef..e48c15f 100644 --- a/tests/tests_with_docker_compose/queue_test.py +++ b/tests/tests_with_docker_compose/queue_test.py @@ -1,38 +1,46 @@ -import gzip import json from multiprocessing import Process from time import sleep from kn_utils.logging import logger +from pyinfra.config import get_config from pyinfra.queue.development_queue_manager import DevelopmentQueueManager -from pyinfra.queue.queue_manager import QueueManager +from pyinfra.queue.queue_manager import QueueManager, QueueManagerV2 +def callback(x): + sleep(4) + response = json.dumps({"status": "success"}).encode("utf-8") + return response + class TestQueueManager: def test_basic_functionality(self, settings): - settings.rabbitmq_heartbeat = 7200 - development_queue_manager = DevelopmentQueueManager(settings) - message = { "targetFilePath": "test/target.json.gz", "responseFilePath": "test/response.json.gz", } - development_queue_manager.publish_request(message) + queue_manager = QueueManagerV2() + # queue_manager_old = QueueManager(get_config()) - queue_manager = QueueManager(settings) + queue_manager.publish_message(message) + queue_manager.publish_message(message) + queue_manager.publish_message(message) + logger.info("Published message") - consume = lambda: queue_manager.start_consuming(lambda x: x) + # consume = lambda: 
queue_manager.start_consuming(callback) + consume = lambda: queue_manager.start_consuming(callback) p = Process(target=consume) p.start() - wait_time = 1 - logger.info(f"Waiting {wait_time} seconds for the consumer to process the message...") + wait_time = 20 + # logger.info(f"Waiting {wait_time} seconds for the consumer to process the message...") sleep(wait_time) p.kill() - response = development_queue_manager.get_response() + response = queue_manager.get_message() + logger.info(f"Response: {response}") print(response)