refactor: finish queue manager and queue manager tests; also add validation logic and integrate new settings

This commit is contained in:
Julius Unverfehrt 2024-01-16 14:16:27 +01:00
parent b49645cce4
commit ebc519ee0d
7 changed files with 244 additions and 368 deletions

View File

@ -1,40 +0,0 @@
[logging]
level = "DEBUG"
[metrics.prometheus]
enabled = true
prefix = "redactmanager_research_service_parameter" # convention: '{product_name}_{service_name}_{parameter}'
host = "0.0.0.0"
port = 8080
[rabbitmq]
host = "localhost"
port = "5672"
username = "user"
password = "bitnami"
heartbeat = 5
connection_sleep = 5
write_consumer_token = false
input_queue = "request_queue"
output_queue = "response_queue"
dead_letter_queue = "dead_letter_queue"
[storage]
type = "s3"
[storage.s3]
bucket = "redaction"
endpoint = "http://127.0.0.1:9000"
key = "root"
secret = "password"
region = "eu-central-1"
[storage.azure]
container = "redaction"
connection_string = "DefaultEndpointsProtocol=..."
[multi_tenancy.server]
public_key = "redaction"
endpoint = "http://tenant-user-management:8081/internal-api/tenants"

View File

@ -1,6 +1,10 @@
import os
from os import environ
from pathlib import Path
from typing import Union
from dynaconf import Dynaconf
from pyinfra.utils.url_parsing import validate_and_parse_s3_endpoint
@ -45,7 +49,7 @@ class Config:
self.rabbitmq_password = read_from_environment("RABBITMQ_PASSWORD", "bitnami")
# Controls AMQP heartbeat timeout in seconds
self.rabbitmq_heartbeat = int(read_from_environment("RABBITMQ_HEARTBEAT", 60))
self.rabbitmq_heartbeat = int(read_from_environment("RABBITMQ_HEARTBEAT", 1))
# Controls AMQP connection sleep timer in seconds
# important for heartbeat to come through while main function runs on other thread
@ -96,7 +100,9 @@ class Config:
# config for x-tenant-endpoint to receive storage connection information per tenant
self.tenant_decryption_public_key = read_from_environment("TENANT_PUBLIC_KEY", "redaction")
self.tenant_endpoint = read_from_environment("TENANT_ENDPOINT", "http://tenant-user-management:8081/internal-api/tenants")
self.tenant_endpoint = read_from_environment(
"TENANT_ENDPOINT", "http://tenant-user-management:8081/internal-api/tenants"
)
# Value to see if we should write a consumer token to a file
self.write_consumer_token = read_from_environment("WRITE_CONSUMER_TOKEN", "False")
@ -104,3 +110,22 @@ class Config:
def get_config() -> Config:
return Config()
def load_settings():
    """Load Dynaconf settings from the repository's ``config/settings.toml``.

    Side effects: exports the ``ROOT_PATH`` (the ``pyinfra/`` package
    directory) and ``REPO_ROOT_PATH`` (repository root) environment
    variables so the settings file can interpolate them.

    Returns:
        Dynaconf: the loaded settings object.
    """
    # TODO: Make dynamic, so that the settings.toml file can be loaded from any location
    # TODO: add validation
    root_path = Path(__file__).resolve().parent  # this is pyinfra/
    repo_root_path = root_path.parent  # this is the root of the repo
    os.environ["ROOT_PATH"] = str(root_path)
    os.environ["REPO_ROOT_PATH"] = str(repo_root_path)
    settings = Dynaconf(
        load_dotenv=True,
        envvar_prefix=False,
        settings_files=[
            repo_root_path / "config" / "settings.toml",
        ],
    )
    return settings

View File

@ -1,40 +0,0 @@
import json
import pika
import pika.exceptions
from pyinfra.config import Config
from pyinfra.queue.queue_manager import QueueManager
class DevelopmentQueueManager(QueueManager):
    """Queue manager variant for tests and scripts.

    Extends the production queue manager with capabilities that are only
    needed outside production, such as publishing messages and purging
    queues.
    """

    def __init__(self, config: Config):
        super().__init__(config)
        # Production code opens the channel lazily; tests need it right away.
        self._open_channel()

    def publish_request(self, message: dict, properties: pika.BasicProperties = None):
        # Serialize the payload exactly like production publishers do.
        encoded_body = json.dumps(message).encode("utf-8")
        self._channel.basic_publish(
            "",
            self._input_queue,
            properties=properties,
            body=encoded_body,
        )

    def get_response(self):
        return self._channel.basic_get(self._output_queue)

    def clear_queues(self):
        """Purge both the input and the output queue."""
        try:
            self._channel.queue_purge(self._input_queue)
            self._channel.queue_purge(self._output_queue)
        except pika.exceptions.ChannelWrongStateError:
            # A closed channel means there is nothing left to purge.
            pass

    def close_channel(self):
        self._channel.close()

View File

@ -2,220 +2,27 @@ import atexit
import concurrent.futures
import json
import logging
import signal
import sys
import threading
import time
from functools import partial
from typing import Union, Callable
import pika
import pika.exceptions
import signal
from dynaconf import Dynaconf
from kn_utils.logging import logger
from pathlib import Path
from pika.adapters.blocking_connection import BlockingChannel, BlockingConnection
from retry import retry
from pyinfra.config import Config, load_settings
from pyinfra.exception import ProcessingFailure
from pyinfra.payload_processing.processor import PayloadProcessor
from pyinfra.utils.dict import safe_project
CONFIG = Config()
from pyinfra.utils.config_validation import validate_settings, queue_manager_validators
pika_logger = logging.getLogger("pika")
pika_logger.setLevel(logging.WARNING) # disables non-informative pika log clutter
def get_connection_params(config: Config) -> pika.ConnectionParameters:
    """Build pika connection parameters from the pyinfra Config class.

    Args:
        config (pyinfra.Config): standard pyinfra config object.

    Returns:
        pika.ConnectionParameters: standard pika connection param object.
    """
    plain_credentials = pika.PlainCredentials(
        username=config.rabbitmq_username,
        password=config.rabbitmq_password,
    )
    return pika.ConnectionParameters(
        host=config.rabbitmq_host,
        port=config.rabbitmq_port,
        credentials=plain_credentials,
        heartbeat=config.rabbitmq_heartbeat,
    )
def _get_n_previous_attempts(props):
return 0 if props.headers is None else props.headers.get("x-retry-count", 0)
def token_file_name():
    """Build the path of the file the consumer token is written to.

    Returns:
        pathlib.Path: path to /tmp/consumer_token.txt
    """
    # Fixed location; _set_consumer_token writes the token here when enabled.
    return Path("/tmp") / "consumer_token.txt"
class QueueManager:
"""Handle RabbitMQ message reception & delivery"""
def __init__(self, settings: Dynaconf):
validate_settings(settings, queue_manager_validators)
def __init__(self, config: Config):
self._input_queue = config.request_queue
self._output_queue = config.response_queue
self._dead_letter_queue = config.dead_letter_queue
# controls how often we send out a life signal
self._heartbeat = config.rabbitmq_heartbeat
# controls for how long we only process data events (e.g. heartbeats),
# while the queue is blocked and we process the given callback function
self._connection_sleep = config.rabbitmq_connection_sleep
self._write_token = config.write_consumer_token == "True"
self._set_consumer_token(None)
self._connection_params = get_connection_params(config)
self._connection = pika.BlockingConnection(parameters=self._connection_params)
self._channel: BlockingChannel
# necessary to pods can be terminated/restarted in K8s/docker
atexit.register(self.stop_consuming)
signal.signal(signal.SIGTERM, self._handle_stop_signal)
signal.signal(signal.SIGINT, self._handle_stop_signal)
def _set_consumer_token(self, token_value):
self._consumer_token = token_value
if self._write_token:
token_file_path = token_file_name()
with token_file_path.open(mode="w", encoding="utf8") as token_file:
text = token_value if token_value is not None else ""
token_file.write(text)
def _open_channel(self):
self._channel = self._connection.channel()
self._channel.basic_qos(prefetch_count=1)
args = {
"x-dead-letter-exchange": "",
"x-dead-letter-routing-key": self._dead_letter_queue,
}
self._channel.queue_declare(self._input_queue, arguments=args, auto_delete=False, durable=True)
self._channel.queue_declare(self._output_queue, arguments=args, auto_delete=False, durable=True)
def start_consuming(self, process_payload: PayloadProcessor):
"""consumption handling
- standard callback handling is enforced through wrapping process_message_callback in _create_queue_callback
(implements threading to support heartbeats)
- initially sets consumer token to None
- tries to
- open channels
- set consumer token to basic_consume, passing in the standard callback and input queue name
- calls pika start_consuming method on the channels
- catches all Exceptions & stops consuming + closes channels
Args:
process_payload (Callable): function passed to the queue manager, configured by implementing service
"""
callback = self._create_queue_callback(process_payload)
self._set_consumer_token(None)
try:
self._open_channel()
self._set_consumer_token(self._channel.basic_consume(self._input_queue, callback))
logger.info(f"Registered with consumer-tag: {self._consumer_token}")
self._channel.start_consuming()
except Exception:
logger.error(
"An unexpected exception occurred while consuming messages. Consuming will stop.", exc_info=True
)
raise
finally:
self.stop_consuming()
self._connection.close()
def stop_consuming(self):
if self._consumer_token and self._connection:
logger.info(f"Cancelling subscription for consumer-tag {self._consumer_token}")
self._channel.stop_consuming(self._consumer_token)
self._set_consumer_token(None)
def _handle_stop_signal(self, signal_number, _stack_frame, *args, **kwargs):
logger.info(f"Received signal {signal_number}")
self.stop_consuming()
def _create_queue_callback(self, process_payload: PayloadProcessor):
def process_message_body_and_await_result(unpacked_message_body):
with concurrent.futures.ThreadPoolExecutor(max_workers=1) as thread_pool_executor:
logger.debug("Processing payload in separate thread.")
future = thread_pool_executor.submit(process_payload, unpacked_message_body)
while future.running():
logger.debug("Waiting for payload processing to finish...")
self._connection.sleep(float(self._connection_sleep))
try:
return future.result()
except Exception as err:
raise ProcessingFailure(f"QueueMessagePayload processing failed: {repr(err)}") from err
def acknowledge_message_and_publish_response(frame, headers, response_body):
response_properties = pika.BasicProperties(headers=headers) if headers else None
self._channel.basic_publish("", self._output_queue, json.dumps(response_body).encode(), response_properties)
logger.debug(f"Result published, acknowledging incoming message with delivery_tag {frame.delivery_tag}.")
self._channel.basic_ack(frame.delivery_tag)
def callback(_channel, frame, properties, body):
logger.info(f"Received message from queue with delivery_tag {frame.delivery_tag}.")
logger.debug(f"Message headers: {properties.headers}")
# Only try to process each message once. Re-queueing will be handled by the dead-letter-exchange. This
# prevents endless retries on messages that are impossible to process.
if frame.redelivered:
logger.info(
f"Aborting message processing for delivery_tag {frame.delivery_tag} due to it being redelivered.",
)
self._channel.basic_nack(frame.delivery_tag, requeue=False)
return
try:
logger.debug(f"Processing {frame}, {properties}, {body}")
filtered_message_headers = safe_project(properties.headers, ["X-TENANT-ID"])
message_body = {**json.loads(body), **filtered_message_headers}
processing_result = process_message_body_and_await_result(message_body)
logger.info(
f"Processed message with delivery_tag {frame.delivery_tag}, publishing result to result-queue."
)
acknowledge_message_and_publish_response(frame, filtered_message_headers, processing_result)
except ProcessingFailure as err:
logger.info(f"Processing message with delivery_tag {frame.delivery_tag} failed, declining.")
logger.exception(err)
self._channel.basic_nack(frame.delivery_tag, requeue=False)
except Exception:
n_attempts = _get_n_previous_attempts(properties) + 1
logger.warning(f"Failed to process message, {n_attempts}", exc_info=True)
self._channel.basic_nack(frame.delivery_tag, requeue=False)
raise
return callback
class QueueManagerV2:
def __init__(self, settings: Dynaconf = load_settings()):
self.input_queue = settings.rabbitmq.input_queue
self.output_queue = settings.rabbitmq.output_queue
self.dead_letter_queue = settings.rabbitmq.dead_letter_queue
@ -224,9 +31,7 @@ class QueueManagerV2:
self.connection: Union[BlockingConnection, None] = None
self.channel: Union[BlockingChannel, None] = None
self.consumer_thread: Union[threading.Thread, None] = None
self.worker_threads: list[threading.Thread] = []
self.connection_sleep = settings.rabbitmq.connection_sleep
atexit.register(self.stop_consuming)
signal.signal(signal.SIGTERM, self._handle_stop_signal)
@ -244,7 +49,7 @@ class QueueManagerV2:
return pika.ConnectionParameters(**pika_connection_params)
@retry(tries=5, delay=5, jitter=(1, 3))
@retry(tries=3, delay=5, jitter=(1, 3), logger=logger)
def establish_connection(self):
# TODO: set sensible retry parameters
if self.connection and self.connection.is_open:
@ -253,6 +58,8 @@ class QueueManagerV2:
logger.info("Establishing connection to RabbitMQ...")
self.connection = pika.BlockingConnection(parameters=self.connection_parameters)
logger.debug("Opening channel...")
self.channel = self.connection.channel()
self.channel.basic_qos(prefetch_count=1)
@ -263,60 +70,24 @@ class QueueManagerV2:
self.channel.queue_declare(self.input_queue, arguments=args, auto_delete=False, durable=True)
self.channel.queue_declare(self.output_queue, arguments=args, auto_delete=False, durable=True)
logger.info("Connection to RabbitMQ established.")
def publish_message(self, message: dict, properties: pika.BasicProperties = None):
logger.info("Connection to RabbitMQ established, channel open.")
def is_ready(self):
self.establish_connection()
message_encoded = json.dumps(message).encode("utf-8")
self.channel.basic_publish(
"",
self.input_queue,
properties=properties,
body=message_encoded,
)
logger.info(f"Published message to queue {self.input_queue}.")
def get_message(self):
self.establish_connection()
return self.channel.basic_get(self.output_queue)
def create_on_message_callback(self, callback: Callable):
def process_message_body_and_await_result(unpacked_message_body):
with concurrent.futures.ThreadPoolExecutor(max_workers=1) as thread_pool_executor:
logger.debug("Processing payload in separate thread.")
future = thread_pool_executor.submit(callback, unpacked_message_body)
while future.running():
logger.debug("Waiting for payload processing to finish...")
self.connection.process_data_events()
self.connection.sleep(5)
return future.result()
def cb(ch, method, properties, body):
logger.info(f"Received message from queue with delivery_tag {method.delivery_tag}.")
result = process_message_body_and_await_result(body)
logger.info(f"Processed message with delivery_tag {method.delivery_tag}, publishing result to result-queue.")
ch.basic_publish(
"",
self.output_queue,
result,
)
ch.basic_ack(delivery_tag=method.delivery_tag)
logger.info(f"Message with delivery tag {method.delivery_tag} acknowledged.")
return cb
return self.channel.is_open
def start_consuming(self, message_processor: Callable):
on_message_callback = self.create_on_message_callback(message_processor)
self.establish_connection()
self.channel.basic_consume(self.input_queue, on_message_callback)
on_message_callback = self._make_on_message_callback(message_processor)
try:
self.establish_connection()
self.channel.basic_consume(self.input_queue, on_message_callback)
self.channel.start_consuming()
except KeyboardInterrupt:
except Exception:
logger.error("An unexpected error occurred while consuming messages. Consuming will stop.", exc_info=True)
raise
finally:
self.stop_consuming()
def stop_consuming(self):
@ -330,14 +101,88 @@ class QueueManagerV2:
logger.info("Closing connection to RabbitMQ...")
self.connection.close()
logger.info("Waiting for worker threads to finish...")
def publish_message_to_input_queue(self, message: Union[str, bytes, dict], properties: pika.BasicProperties = None):
if isinstance(message, str):
message = message.encode("utf-8")
elif isinstance(message, dict):
message = json.dumps(message).encode("utf-8")
for thread in self.worker_threads:
logger.info(f"Stopping worker thread {thread.name}...")
thread.join()
logger.info(f"Worker thread {thread.name} stopped.")
self.establish_connection()
self.channel.basic_publish(
"",
self.input_queue,
properties=properties,
body=message,
)
logger.info(f"Published message to queue {self.input_queue}.")
def purge_queues(self):
self.establish_connection()
try:
self.channel.queue_purge(self.input_queue)
self.channel.queue_purge(self.output_queue)
logger.info("Queues purged.")
except pika.exceptions.ChannelWrongStateError:
pass
def get_message_from_output_queue(self):
self.establish_connection()
return self.channel.basic_get(self.output_queue, auto_ack=True)
def _make_on_message_callback(self, message_processor: Callable):
def process_message_body_and_await_result(unpacked_message_body):
with concurrent.futures.ThreadPoolExecutor(max_workers=1) as thread_pool_executor:
logger.debug("Processing payload in separate thread.")
future = thread_pool_executor.submit(message_processor, unpacked_message_body)
while future.running():
logger.debug("Waiting for payload processing to finish...")
self.connection.process_data_events()
self.connection.sleep(self.connection_sleep)
return future.result()
def on_message_callback(channel, method, properties, body):
logger.info(f"Received message from queue with delivery_tag {method.delivery_tag}.")
if method.redelivered:
logger.warning(f"Declining message with {method.delivery_tag=} due to it being redelivered.")
channel.basic_nack(method.delivery_tag, requeue=False)
return
if body.decode("utf-8") == "STOP":
logger.info(f"Received stop signal, stopping consuming...")
channel.basic_ack(delivery_tag=method.delivery_tag)
self.stop_consuming()
return
try:
filtered_message_headers = (
{k: v for k, v in properties.headers.items() if k.lower().startswith("x-")}
if properties.headers
else {}
)
logger.debug(f"Processing message with {filtered_message_headers=}.")
result = process_message_body_and_await_result({**json.loads(body), **filtered_message_headers})
channel.basic_publish(
"",
self.output_queue,
result,
properties=pika.BasicProperties(headers=filtered_message_headers),
)
logger.info(f"Published result to queue {self.output_queue}.")
channel.basic_ack(delivery_tag=method.delivery_tag)
logger.debug(f"Message with {method.delivery_tag=} acknowledged.")
except Exception:
logger.warning(f"Failed to process message with {method.delivery_tag=}, declining...", exc_info=True)
channel.basic_nack(method.delivery_tag, requeue=False)
raise
return on_message_callback
def _handle_stop_signal(self, signum, *args, **kwargs):
logger.info(f"Received signal {signum}, stopping consuming...")
self.stop_consuming()
sys.exit(0)
sys.exit(0)

View File

@ -0,0 +1,30 @@
from dynaconf import Validator, Dynaconf, ValidationError
from kn_utils.logging import logger
# Every RabbitMQ setting the QueueManager reads must be present in the
# loaded settings; each key below becomes a must_exist validator.
_REQUIRED_RABBITMQ_KEYS = (
    "host",
    "port",
    "username",
    "password",
    "heartbeat",
    "connection_sleep",
    "input_queue",
    "output_queue",
    "dead_letter_queue",
)
queue_manager_validators = [
    Validator(f"rabbitmq.{key}", must_exist=True) for key in _REQUIRED_RABBITMQ_KEYS
]
def validate_settings(settings: Dynaconf, validators):
    """Run every validator against the settings and raise if any fail.

    Each individual failure is logged as a warning so that all invalid or
    missing settings are reported in a single pass; afterwards a single
    ValidationError is raised if anything failed.

    Raises:
        ValidationError: if at least one validator did not pass.
    """
    failures = []
    for validator in validators:
        try:
            validator.validate(settings)
        except ValidationError as error:
            logger.warning(error)
            failures.append(error)
    if failures:
        raise ValidationError("Settings validation failed.")
    logger.info("Settings validated.")

View File

@ -2,13 +2,13 @@ import gzip
import json
import pytest
from pyinfra.config import get_config
from pyinfra.config import get_config, load_settings
from pyinfra.payload_processing.payload import LegacyQueueMessagePayload, QueueMessagePayload
@pytest.fixture(scope="session")
def settings():
return get_config()
return load_settings()
@pytest.fixture

View File

@ -1,46 +1,102 @@
import json
from multiprocessing import Process
from sys import stdout
from time import sleep
import pika
import pytest
from kn_utils.logging import logger
from pyinfra.config import get_config
from pyinfra.queue.development_queue_manager import DevelopmentQueueManager
from pyinfra.queue.queue_manager import QueueManager, QueueManagerV2
from pyinfra.queue.queue_manager import QueueManager
logger.remove()
logger.add(sink=stdout, level="DEBUG")
def callback(x):
    # Dummy payload processor: simulate a long-running job (4 s), then
    # return the JSON-encoded success body the tests assert on.
    sleep(4)
    response = json.dumps({"status": "success"}).encode("utf-8")
    return response
def make_callback(process_time):
    """Create a dummy payload processor for the queue tests.

    The returned callable ignores its input, sleeps for ``process_time``
    seconds to simulate work, and returns a JSON-encoded success body.
    """

    def _callback(_payload):
        sleep(process_time)
        return json.dumps({"status": "success"}).encode("utf-8")

    return _callback
@pytest.fixture(scope="session")
def queue_manager(settings):
    # Shorten heartbeat and connection sleep so the integration tests run
    # with tighter timing than the defaults.
    # NOTE(review): these flat attribute names look like the legacy Config
    # fields; confirm they take effect on the Dynaconf settings object
    # (which nests them under settings.rabbitmq.*).
    settings.rabbitmq_heartbeat = 10
    settings.connection_sleep = 5
    queue_manager = QueueManager(settings)
    yield queue_manager
@pytest.fixture
def input_message():
    """JSON-encoded request payload referencing the test fixture files."""
    payload = {
        "targetFilePath": "test/target.json.gz",
        "responseFilePath": "test/response.json.gz",
    }
    return json.dumps(payload)
@pytest.fixture
def stop_message():
    # Sentinel body that tells the consumer loop to stop consuming.
    return "STOP"
class TestQueueManager:
def test_basic_functionality(self, settings):
message = {
"targetFilePath": "test/target.json.gz",
"responseFilePath": "test/response.json.gz",
}
def test_processing_of_several_messages(self, queue_manager, input_message, stop_message):
queue_manager.purge_queues()
queue_manager = QueueManagerV2()
# queue_manager_old = QueueManager(get_config())
for _ in range(2):
queue_manager.publish_message_to_input_queue(input_message)
queue_manager.publish_message(message)
queue_manager.publish_message(message)
queue_manager.publish_message(message)
logger.info("Published message")
queue_manager.publish_message_to_input_queue(stop_message)
# consume = lambda: queue_manager.start_consuming(callback)
consume = lambda: queue_manager.start_consuming(callback)
p = Process(target=consume)
p.start()
callback = make_callback(1)
queue_manager.start_consuming(callback)
wait_time = 20
# logger.info(f"Waiting {wait_time} seconds for the consumer to process the message...")
sleep(wait_time)
for _ in range(2):
response = queue_manager.get_message_from_output_queue()
assert response is not None
assert response[2] == b'{"status": "success"}'
print(response)
p.kill()
def test_all_headers_beginning_with_x_are_forwarded(self, queue_manager, input_message, stop_message):
queue_manager.purge_queues()
response = queue_manager.get_message()
properties = pika.BasicProperties(
headers={
"X-TENANT-ID": "redaction",
"X-OTHER-HEADER": "other-header-value",
"x-tenant_id": "tenant-id-value",
"x_should_not_be_forwarded": "should-not-be-forwarded-value",
}
)
logger.info(f"Response: {response}")
queue_manager.publish_message_to_input_queue(input_message, properties=properties)
queue_manager.publish_message_to_input_queue(stop_message)
callback = make_callback(0.2)
queue_manager.start_consuming(callback)
response = queue_manager.get_message_from_output_queue()
print(response)
assert response[2] == b'{"status": "success"}'
assert response[1].headers["X-TENANT-ID"] == "redaction"
assert response[1].headers["X-OTHER-HEADER"] == "other-header-value"
assert response[1].headers["x-tenant_id"] == "tenant-id-value"
assert "x_should_not_be_forwarded" not in response[1].headers
def test_message_processing_does_not_block_heartbeat(self, queue_manager, input_message, stop_message):
queue_manager.purge_queues()
queue_manager.publish_message_to_input_queue(input_message)
queue_manager.publish_message_to_input_queue(stop_message)
callback = make_callback(15)
queue_manager.start_consuming(callback)
response = queue_manager.get_message_from_output_queue()
assert response[2] == b'{"status": "success"}'