refactor: finnish queue manager, queue manager tests, also add validation logic, integrate new settings
This commit is contained in:
parent
b49645cce4
commit
ebc519ee0d
@ -1,40 +0,0 @@
|
||||
[logging]
|
||||
level = "DEBUG"
|
||||
|
||||
[metrics.prometheus]
|
||||
enabled = true
|
||||
prefix = "redactmanager_research_service_parameter" # convention: '{product_name}_{service_name}_{parameter}'
|
||||
host = "0.0.0.0"
|
||||
port = 8080
|
||||
|
||||
[rabbitmq]
|
||||
host = "localhost"
|
||||
port = "5672"
|
||||
username = "user"
|
||||
password = "bitnami"
|
||||
heartbeat = 5
|
||||
connection_sleep = 5
|
||||
write_consumer_token = false
|
||||
input_queue = "request_queue"
|
||||
output_queue = "response_queue"
|
||||
dead_letter_queue = "dead_letter_queue"
|
||||
|
||||
[storage]
|
||||
type = "s3"
|
||||
|
||||
[storage.s3]
|
||||
bucket = "redaction"
|
||||
endpoint = "http://127.0.0.1:9000"
|
||||
key = "root"
|
||||
secret = "password"
|
||||
region = "eu-central-1"
|
||||
|
||||
[storage.azure]
|
||||
container = "redaction"
|
||||
connection_string = "DefaultEndpointsProtocol=..."
|
||||
|
||||
[multi_tenancy.server]
|
||||
public_key = "redaction"
|
||||
endpoint = "http://tenant-user-management:8081/internal-api/tenants"
|
||||
|
||||
|
||||
@ -1,6 +1,10 @@
|
||||
import os
|
||||
from os import environ
|
||||
from pathlib import Path
|
||||
from typing import Union
|
||||
|
||||
from dynaconf import Dynaconf
|
||||
|
||||
from pyinfra.utils.url_parsing import validate_and_parse_s3_endpoint
|
||||
|
||||
|
||||
@ -45,7 +49,7 @@ class Config:
|
||||
self.rabbitmq_password = read_from_environment("RABBITMQ_PASSWORD", "bitnami")
|
||||
|
||||
# Controls AMQP heartbeat timeout in seconds
|
||||
self.rabbitmq_heartbeat = int(read_from_environment("RABBITMQ_HEARTBEAT", 60))
|
||||
self.rabbitmq_heartbeat = int(read_from_environment("RABBITMQ_HEARTBEAT", 1))
|
||||
|
||||
# Controls AMQP connection sleep timer in seconds
|
||||
# important for heartbeat to come through while main function runs on other thread
|
||||
@ -96,7 +100,9 @@ class Config:
|
||||
|
||||
# config for x-tenant-endpoint to receive storage connection information per tenant
|
||||
self.tenant_decryption_public_key = read_from_environment("TENANT_PUBLIC_KEY", "redaction")
|
||||
self.tenant_endpoint = read_from_environment("TENANT_ENDPOINT", "http://tenant-user-management:8081/internal-api/tenants")
|
||||
self.tenant_endpoint = read_from_environment(
|
||||
"TENANT_ENDPOINT", "http://tenant-user-management:8081/internal-api/tenants"
|
||||
)
|
||||
|
||||
# Value to see if we should write a consumer token to a file
|
||||
self.write_consumer_token = read_from_environment("WRITE_CONSUMER_TOKEN", "False")
|
||||
@ -104,3 +110,22 @@ class Config:
|
||||
|
||||
def get_config() -> Config:
|
||||
return Config()
|
||||
|
||||
|
||||
def load_settings():
|
||||
# TODO: Make dynamic, so that the settings.toml file can be loaded from any location
|
||||
# TODO: add validation
|
||||
root_path = Path(__file__).resolve().parents[0] # this is pyinfra/
|
||||
repo_root_path = root_path.parents[0] # this is the root of the repo
|
||||
os.environ["ROOT_PATH"] = str(root_path)
|
||||
os.environ["REPO_ROOT_PATH"] = str(repo_root_path)
|
||||
|
||||
settings = Dynaconf(
|
||||
load_dotenv=True,
|
||||
envvar_prefix=False,
|
||||
settings_files=[
|
||||
repo_root_path / "config" / "settings.toml",
|
||||
],
|
||||
)
|
||||
|
||||
return settings
|
||||
|
||||
@ -1,40 +0,0 @@
|
||||
import json
|
||||
|
||||
import pika
|
||||
import pika.exceptions
|
||||
|
||||
from pyinfra.config import Config
|
||||
from pyinfra.queue.queue_manager import QueueManager
|
||||
|
||||
|
||||
class DevelopmentQueueManager(QueueManager):
|
||||
"""Extends the queue manger with additional functionality that is needed for tests and scripts,
|
||||
but not in production, such as publishing messages.
|
||||
"""
|
||||
|
||||
def __init__(self, config: Config):
|
||||
super().__init__(config)
|
||||
self._open_channel()
|
||||
|
||||
def publish_request(self, message: dict, properties: pika.BasicProperties = None):
|
||||
message_encoded = json.dumps(message).encode("utf-8")
|
||||
self._channel.basic_publish(
|
||||
"",
|
||||
self._input_queue,
|
||||
properties=properties,
|
||||
body=message_encoded,
|
||||
)
|
||||
|
||||
def get_response(self):
|
||||
return self._channel.basic_get(self._output_queue)
|
||||
|
||||
def clear_queues(self):
|
||||
"""purge input & output queues"""
|
||||
try:
|
||||
self._channel.queue_purge(self._input_queue)
|
||||
self._channel.queue_purge(self._output_queue)
|
||||
except pika.exceptions.ChannelWrongStateError:
|
||||
pass
|
||||
|
||||
def close_channel(self):
|
||||
self._channel.close()
|
||||
@ -2,220 +2,27 @@ import atexit
|
||||
import concurrent.futures
|
||||
import json
|
||||
import logging
|
||||
import signal
|
||||
import sys
|
||||
import threading
|
||||
import time
|
||||
from functools import partial
|
||||
from typing import Union, Callable
|
||||
|
||||
import pika
|
||||
import pika.exceptions
|
||||
import signal
|
||||
|
||||
from dynaconf import Dynaconf
|
||||
from kn_utils.logging import logger
|
||||
from pathlib import Path
|
||||
from pika.adapters.blocking_connection import BlockingChannel, BlockingConnection
|
||||
from retry import retry
|
||||
|
||||
from pyinfra.config import Config, load_settings
|
||||
from pyinfra.exception import ProcessingFailure
|
||||
from pyinfra.payload_processing.processor import PayloadProcessor
|
||||
from pyinfra.utils.dict import safe_project
|
||||
|
||||
CONFIG = Config()
|
||||
from pyinfra.utils.config_validation import validate_settings, queue_manager_validators
|
||||
|
||||
pika_logger = logging.getLogger("pika")
|
||||
pika_logger.setLevel(logging.WARNING) # disables non-informative pika log clutter
|
||||
|
||||
|
||||
def get_connection_params(config: Config) -> pika.ConnectionParameters:
|
||||
"""creates pika connection params from pyinfra.Config class
|
||||
|
||||
Args:
|
||||
config (pyinfra.Config): standard pyinfra config class
|
||||
|
||||
Returns:
|
||||
pika.ConnectionParameters: standard pika connection param object
|
||||
"""
|
||||
credentials = pika.PlainCredentials(username=config.rabbitmq_username, password=config.rabbitmq_password)
|
||||
pika_connection_params = {
|
||||
"host": config.rabbitmq_host,
|
||||
"port": config.rabbitmq_port,
|
||||
"credentials": credentials,
|
||||
"heartbeat": config.rabbitmq_heartbeat,
|
||||
}
|
||||
|
||||
return pika.ConnectionParameters(**pika_connection_params)
|
||||
|
||||
|
||||
def _get_n_previous_attempts(props):
|
||||
return 0 if props.headers is None else props.headers.get("x-retry-count", 0)
|
||||
|
||||
|
||||
def token_file_name():
|
||||
"""create filepath
|
||||
|
||||
Returns:
|
||||
joblib.Path: filepath
|
||||
"""
|
||||
token_file_path = Path("/tmp") / "consumer_token.txt"
|
||||
return token_file_path
|
||||
|
||||
|
||||
class QueueManager:
|
||||
"""Handle RabbitMQ message reception & delivery"""
|
||||
def __init__(self, settings: Dynaconf):
|
||||
validate_settings(settings, queue_manager_validators)
|
||||
|
||||
def __init__(self, config: Config):
|
||||
self._input_queue = config.request_queue
|
||||
self._output_queue = config.response_queue
|
||||
self._dead_letter_queue = config.dead_letter_queue
|
||||
|
||||
# controls how often we send out a life signal
|
||||
self._heartbeat = config.rabbitmq_heartbeat
|
||||
|
||||
# controls for how long we only process data events (e.g. heartbeats),
|
||||
# while the queue is blocked and we process the given callback function
|
||||
self._connection_sleep = config.rabbitmq_connection_sleep
|
||||
|
||||
self._write_token = config.write_consumer_token == "True"
|
||||
self._set_consumer_token(None)
|
||||
|
||||
self._connection_params = get_connection_params(config)
|
||||
self._connection = pika.BlockingConnection(parameters=self._connection_params)
|
||||
self._channel: BlockingChannel
|
||||
|
||||
# necessary to pods can be terminated/restarted in K8s/docker
|
||||
atexit.register(self.stop_consuming)
|
||||
signal.signal(signal.SIGTERM, self._handle_stop_signal)
|
||||
signal.signal(signal.SIGINT, self._handle_stop_signal)
|
||||
|
||||
def _set_consumer_token(self, token_value):
|
||||
self._consumer_token = token_value
|
||||
|
||||
if self._write_token:
|
||||
token_file_path = token_file_name()
|
||||
|
||||
with token_file_path.open(mode="w", encoding="utf8") as token_file:
|
||||
text = token_value if token_value is not None else ""
|
||||
token_file.write(text)
|
||||
|
||||
def _open_channel(self):
|
||||
self._channel = self._connection.channel()
|
||||
self._channel.basic_qos(prefetch_count=1)
|
||||
|
||||
args = {
|
||||
"x-dead-letter-exchange": "",
|
||||
"x-dead-letter-routing-key": self._dead_letter_queue,
|
||||
}
|
||||
|
||||
self._channel.queue_declare(self._input_queue, arguments=args, auto_delete=False, durable=True)
|
||||
self._channel.queue_declare(self._output_queue, arguments=args, auto_delete=False, durable=True)
|
||||
|
||||
def start_consuming(self, process_payload: PayloadProcessor):
|
||||
"""consumption handling
|
||||
- standard callback handling is enforced through wrapping process_message_callback in _create_queue_callback
|
||||
(implements threading to support heartbeats)
|
||||
- initially sets consumer token to None
|
||||
- tries to
|
||||
- open channels
|
||||
- set consumer token to basic_consume, passing in the standard callback and input queue name
|
||||
- calls pika start_consuming method on the channels
|
||||
- catches all Exceptions & stops consuming + closes channels
|
||||
|
||||
Args:
|
||||
process_payload (Callable): function passed to the queue manager, configured by implementing service
|
||||
"""
|
||||
callback = self._create_queue_callback(process_payload)
|
||||
self._set_consumer_token(None)
|
||||
|
||||
try:
|
||||
self._open_channel()
|
||||
self._set_consumer_token(self._channel.basic_consume(self._input_queue, callback))
|
||||
logger.info(f"Registered with consumer-tag: {self._consumer_token}")
|
||||
self._channel.start_consuming()
|
||||
|
||||
except Exception:
|
||||
logger.error(
|
||||
"An unexpected exception occurred while consuming messages. Consuming will stop.", exc_info=True
|
||||
)
|
||||
raise
|
||||
|
||||
finally:
|
||||
self.stop_consuming()
|
||||
self._connection.close()
|
||||
|
||||
def stop_consuming(self):
|
||||
if self._consumer_token and self._connection:
|
||||
logger.info(f"Cancelling subscription for consumer-tag {self._consumer_token}")
|
||||
self._channel.stop_consuming(self._consumer_token)
|
||||
self._set_consumer_token(None)
|
||||
|
||||
def _handle_stop_signal(self, signal_number, _stack_frame, *args, **kwargs):
|
||||
logger.info(f"Received signal {signal_number}")
|
||||
self.stop_consuming()
|
||||
|
||||
def _create_queue_callback(self, process_payload: PayloadProcessor):
|
||||
def process_message_body_and_await_result(unpacked_message_body):
|
||||
with concurrent.futures.ThreadPoolExecutor(max_workers=1) as thread_pool_executor:
|
||||
logger.debug("Processing payload in separate thread.")
|
||||
future = thread_pool_executor.submit(process_payload, unpacked_message_body)
|
||||
|
||||
while future.running():
|
||||
logger.debug("Waiting for payload processing to finish...")
|
||||
self._connection.sleep(float(self._connection_sleep))
|
||||
|
||||
try:
|
||||
return future.result()
|
||||
except Exception as err:
|
||||
raise ProcessingFailure(f"QueueMessagePayload processing failed: {repr(err)}") from err
|
||||
|
||||
def acknowledge_message_and_publish_response(frame, headers, response_body):
|
||||
response_properties = pika.BasicProperties(headers=headers) if headers else None
|
||||
self._channel.basic_publish("", self._output_queue, json.dumps(response_body).encode(), response_properties)
|
||||
logger.debug(f"Result published, acknowledging incoming message with delivery_tag {frame.delivery_tag}.")
|
||||
self._channel.basic_ack(frame.delivery_tag)
|
||||
|
||||
def callback(_channel, frame, properties, body):
|
||||
logger.info(f"Received message from queue with delivery_tag {frame.delivery_tag}.")
|
||||
logger.debug(f"Message headers: {properties.headers}")
|
||||
|
||||
# Only try to process each message once. Re-queueing will be handled by the dead-letter-exchange. This
|
||||
# prevents endless retries on messages that are impossible to process.
|
||||
if frame.redelivered:
|
||||
logger.info(
|
||||
f"Aborting message processing for delivery_tag {frame.delivery_tag} due to it being redelivered.",
|
||||
)
|
||||
self._channel.basic_nack(frame.delivery_tag, requeue=False)
|
||||
return
|
||||
|
||||
try:
|
||||
logger.debug(f"Processing {frame}, {properties}, {body}")
|
||||
filtered_message_headers = safe_project(properties.headers, ["X-TENANT-ID"])
|
||||
message_body = {**json.loads(body), **filtered_message_headers}
|
||||
|
||||
processing_result = process_message_body_and_await_result(message_body)
|
||||
logger.info(
|
||||
f"Processed message with delivery_tag {frame.delivery_tag}, publishing result to result-queue."
|
||||
)
|
||||
acknowledge_message_and_publish_response(frame, filtered_message_headers, processing_result)
|
||||
|
||||
except ProcessingFailure as err:
|
||||
logger.info(f"Processing message with delivery_tag {frame.delivery_tag} failed, declining.")
|
||||
logger.exception(err)
|
||||
self._channel.basic_nack(frame.delivery_tag, requeue=False)
|
||||
|
||||
except Exception:
|
||||
n_attempts = _get_n_previous_attempts(properties) + 1
|
||||
logger.warning(f"Failed to process message, {n_attempts}", exc_info=True)
|
||||
self._channel.basic_nack(frame.delivery_tag, requeue=False)
|
||||
raise
|
||||
|
||||
return callback
|
||||
|
||||
|
||||
class QueueManagerV2:
|
||||
def __init__(self, settings: Dynaconf = load_settings()):
|
||||
self.input_queue = settings.rabbitmq.input_queue
|
||||
self.output_queue = settings.rabbitmq.output_queue
|
||||
self.dead_letter_queue = settings.rabbitmq.dead_letter_queue
|
||||
@ -224,9 +31,7 @@ class QueueManagerV2:
|
||||
|
||||
self.connection: Union[BlockingConnection, None] = None
|
||||
self.channel: Union[BlockingChannel, None] = None
|
||||
|
||||
self.consumer_thread: Union[threading.Thread, None] = None
|
||||
self.worker_threads: list[threading.Thread] = []
|
||||
self.connection_sleep = settings.rabbitmq.connection_sleep
|
||||
|
||||
atexit.register(self.stop_consuming)
|
||||
signal.signal(signal.SIGTERM, self._handle_stop_signal)
|
||||
@ -244,7 +49,7 @@ class QueueManagerV2:
|
||||
|
||||
return pika.ConnectionParameters(**pika_connection_params)
|
||||
|
||||
@retry(tries=5, delay=5, jitter=(1, 3))
|
||||
@retry(tries=3, delay=5, jitter=(1, 3), logger=logger)
|
||||
def establish_connection(self):
|
||||
# TODO: set sensible retry parameters
|
||||
if self.connection and self.connection.is_open:
|
||||
@ -253,6 +58,8 @@ class QueueManagerV2:
|
||||
|
||||
logger.info("Establishing connection to RabbitMQ...")
|
||||
self.connection = pika.BlockingConnection(parameters=self.connection_parameters)
|
||||
|
||||
logger.debug("Opening channel...")
|
||||
self.channel = self.connection.channel()
|
||||
self.channel.basic_qos(prefetch_count=1)
|
||||
|
||||
@ -263,60 +70,24 @@ class QueueManagerV2:
|
||||
|
||||
self.channel.queue_declare(self.input_queue, arguments=args, auto_delete=False, durable=True)
|
||||
self.channel.queue_declare(self.output_queue, arguments=args, auto_delete=False, durable=True)
|
||||
logger.info("Connection to RabbitMQ established.")
|
||||
|
||||
def publish_message(self, message: dict, properties: pika.BasicProperties = None):
|
||||
logger.info("Connection to RabbitMQ established, channel open.")
|
||||
|
||||
def is_ready(self):
|
||||
self.establish_connection()
|
||||
message_encoded = json.dumps(message).encode("utf-8")
|
||||
self.channel.basic_publish(
|
||||
"",
|
||||
self.input_queue,
|
||||
properties=properties,
|
||||
body=message_encoded,
|
||||
)
|
||||
logger.info(f"Published message to queue {self.input_queue}.")
|
||||
|
||||
def get_message(self):
|
||||
self.establish_connection()
|
||||
return self.channel.basic_get(self.output_queue)
|
||||
|
||||
def create_on_message_callback(self, callback: Callable):
|
||||
|
||||
def process_message_body_and_await_result(unpacked_message_body):
|
||||
with concurrent.futures.ThreadPoolExecutor(max_workers=1) as thread_pool_executor:
|
||||
logger.debug("Processing payload in separate thread.")
|
||||
future = thread_pool_executor.submit(callback, unpacked_message_body)
|
||||
|
||||
while future.running():
|
||||
logger.debug("Waiting for payload processing to finish...")
|
||||
self.connection.process_data_events()
|
||||
self.connection.sleep(5)
|
||||
|
||||
return future.result()
|
||||
|
||||
|
||||
def cb(ch, method, properties, body):
|
||||
logger.info(f"Received message from queue with delivery_tag {method.delivery_tag}.")
|
||||
result = process_message_body_and_await_result(body)
|
||||
logger.info(f"Processed message with delivery_tag {method.delivery_tag}, publishing result to result-queue.")
|
||||
ch.basic_publish(
|
||||
"",
|
||||
self.output_queue,
|
||||
result,
|
||||
)
|
||||
|
||||
ch.basic_ack(delivery_tag=method.delivery_tag)
|
||||
logger.info(f"Message with delivery tag {method.delivery_tag} acknowledged.")
|
||||
|
||||
return cb
|
||||
return self.channel.is_open
|
||||
|
||||
def start_consuming(self, message_processor: Callable):
|
||||
on_message_callback = self.create_on_message_callback(message_processor)
|
||||
self.establish_connection()
|
||||
self.channel.basic_consume(self.input_queue, on_message_callback)
|
||||
on_message_callback = self._make_on_message_callback(message_processor)
|
||||
|
||||
try:
|
||||
self.establish_connection()
|
||||
self.channel.basic_consume(self.input_queue, on_message_callback)
|
||||
self.channel.start_consuming()
|
||||
except KeyboardInterrupt:
|
||||
except Exception:
|
||||
logger.error("An unexpected error occurred while consuming messages. Consuming will stop.", exc_info=True)
|
||||
raise
|
||||
finally:
|
||||
self.stop_consuming()
|
||||
|
||||
def stop_consuming(self):
|
||||
@ -330,14 +101,88 @@ class QueueManagerV2:
|
||||
logger.info("Closing connection to RabbitMQ...")
|
||||
self.connection.close()
|
||||
|
||||
logger.info("Waiting for worker threads to finish...")
|
||||
def publish_message_to_input_queue(self, message: Union[str, bytes, dict], properties: pika.BasicProperties = None):
|
||||
if isinstance(message, str):
|
||||
message = message.encode("utf-8")
|
||||
elif isinstance(message, dict):
|
||||
message = json.dumps(message).encode("utf-8")
|
||||
|
||||
for thread in self.worker_threads:
|
||||
logger.info(f"Stopping worker thread {thread.name}...")
|
||||
thread.join()
|
||||
logger.info(f"Worker thread {thread.name} stopped.")
|
||||
self.establish_connection()
|
||||
self.channel.basic_publish(
|
||||
"",
|
||||
self.input_queue,
|
||||
properties=properties,
|
||||
body=message,
|
||||
)
|
||||
logger.info(f"Published message to queue {self.input_queue}.")
|
||||
|
||||
def purge_queues(self):
|
||||
self.establish_connection()
|
||||
try:
|
||||
self.channel.queue_purge(self.input_queue)
|
||||
self.channel.queue_purge(self.output_queue)
|
||||
logger.info("Queues purged.")
|
||||
except pika.exceptions.ChannelWrongStateError:
|
||||
pass
|
||||
|
||||
def get_message_from_output_queue(self):
|
||||
self.establish_connection()
|
||||
return self.channel.basic_get(self.output_queue, auto_ack=True)
|
||||
|
||||
def _make_on_message_callback(self, message_processor: Callable):
|
||||
def process_message_body_and_await_result(unpacked_message_body):
|
||||
with concurrent.futures.ThreadPoolExecutor(max_workers=1) as thread_pool_executor:
|
||||
logger.debug("Processing payload in separate thread.")
|
||||
future = thread_pool_executor.submit(message_processor, unpacked_message_body)
|
||||
|
||||
while future.running():
|
||||
logger.debug("Waiting for payload processing to finish...")
|
||||
self.connection.process_data_events()
|
||||
self.connection.sleep(self.connection_sleep)
|
||||
|
||||
return future.result()
|
||||
|
||||
def on_message_callback(channel, method, properties, body):
|
||||
logger.info(f"Received message from queue with delivery_tag {method.delivery_tag}.")
|
||||
|
||||
if method.redelivered:
|
||||
logger.warning(f"Declining message with {method.delivery_tag=} due to it being redelivered.")
|
||||
channel.basic_nack(method.delivery_tag, requeue=False)
|
||||
return
|
||||
|
||||
if body.decode("utf-8") == "STOP":
|
||||
logger.info(f"Received stop signal, stopping consuming...")
|
||||
channel.basic_ack(delivery_tag=method.delivery_tag)
|
||||
self.stop_consuming()
|
||||
return
|
||||
|
||||
try:
|
||||
filtered_message_headers = (
|
||||
{k: v for k, v in properties.headers.items() if k.lower().startswith("x-")}
|
||||
if properties.headers
|
||||
else {}
|
||||
)
|
||||
logger.debug(f"Processing message with {filtered_message_headers=}.")
|
||||
result = process_message_body_and_await_result({**json.loads(body), **filtered_message_headers})
|
||||
|
||||
channel.basic_publish(
|
||||
"",
|
||||
self.output_queue,
|
||||
result,
|
||||
properties=pika.BasicProperties(headers=filtered_message_headers),
|
||||
)
|
||||
logger.info(f"Published result to queue {self.output_queue}.")
|
||||
|
||||
channel.basic_ack(delivery_tag=method.delivery_tag)
|
||||
logger.debug(f"Message with {method.delivery_tag=} acknowledged.")
|
||||
except Exception:
|
||||
logger.warning(f"Failed to process message with {method.delivery_tag=}, declining...", exc_info=True)
|
||||
channel.basic_nack(method.delivery_tag, requeue=False)
|
||||
raise
|
||||
|
||||
return on_message_callback
|
||||
|
||||
def _handle_stop_signal(self, signum, *args, **kwargs):
|
||||
logger.info(f"Received signal {signum}, stopping consuming...")
|
||||
self.stop_consuming()
|
||||
sys.exit(0)
|
||||
sys.exit(0)
|
||||
|
||||
30
pyinfra/utils/config_validation.py
Normal file
30
pyinfra/utils/config_validation.py
Normal file
@ -0,0 +1,30 @@
|
||||
from dynaconf import Validator, Dynaconf, ValidationError
|
||||
from kn_utils.logging import logger
|
||||
|
||||
queue_manager_validators = [
|
||||
Validator("rabbitmq.host", must_exist=True),
|
||||
Validator("rabbitmq.port", must_exist=True),
|
||||
Validator("rabbitmq.username", must_exist=True),
|
||||
Validator("rabbitmq.password", must_exist=True),
|
||||
Validator("rabbitmq.heartbeat", must_exist=True),
|
||||
Validator("rabbitmq.connection_sleep", must_exist=True),
|
||||
Validator("rabbitmq.input_queue", must_exist=True),
|
||||
Validator("rabbitmq.output_queue", must_exist=True),
|
||||
Validator("rabbitmq.dead_letter_queue", must_exist=True),
|
||||
]
|
||||
|
||||
|
||||
def validate_settings(settings: Dynaconf, validators):
|
||||
settings_valid = True
|
||||
|
||||
for validator in validators:
|
||||
try:
|
||||
validator.validate(settings)
|
||||
except ValidationError as e:
|
||||
settings_valid = False
|
||||
logger.warning(e)
|
||||
|
||||
if not settings_valid:
|
||||
raise ValidationError("Settings validation failed.")
|
||||
|
||||
logger.info("Settings validated.")
|
||||
@ -2,13 +2,13 @@ import gzip
|
||||
import json
|
||||
import pytest
|
||||
|
||||
from pyinfra.config import get_config
|
||||
from pyinfra.config import get_config, load_settings
|
||||
from pyinfra.payload_processing.payload import LegacyQueueMessagePayload, QueueMessagePayload
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def settings():
|
||||
return get_config()
|
||||
return load_settings()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
||||
@ -1,46 +1,102 @@
|
||||
import json
|
||||
from multiprocessing import Process
|
||||
from sys import stdout
|
||||
from time import sleep
|
||||
|
||||
import pika
|
||||
import pytest
|
||||
from kn_utils.logging import logger
|
||||
|
||||
from pyinfra.config import get_config
|
||||
from pyinfra.queue.development_queue_manager import DevelopmentQueueManager
|
||||
from pyinfra.queue.queue_manager import QueueManager, QueueManagerV2
|
||||
from pyinfra.queue.queue_manager import QueueManager
|
||||
|
||||
logger.remove()
|
||||
logger.add(sink=stdout, level="DEBUG")
|
||||
|
||||
|
||||
def callback(x):
|
||||
sleep(4)
|
||||
response = json.dumps({"status": "success"}).encode("utf-8")
|
||||
return response
|
||||
def make_callback(process_time):
|
||||
def callback(x):
|
||||
sleep(process_time)
|
||||
return json.dumps({"status": "success"}).encode("utf-8")
|
||||
|
||||
return callback
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def queue_manager(settings):
|
||||
settings.rabbitmq_heartbeat = 10
|
||||
settings.connection_sleep = 5
|
||||
queue_manager = QueueManager(settings)
|
||||
yield queue_manager
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def input_message():
|
||||
return json.dumps({
|
||||
"targetFilePath": "test/target.json.gz",
|
||||
"responseFilePath": "test/response.json.gz",
|
||||
})
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def stop_message():
|
||||
return "STOP"
|
||||
|
||||
|
||||
class TestQueueManager:
|
||||
def test_basic_functionality(self, settings):
|
||||
message = {
|
||||
"targetFilePath": "test/target.json.gz",
|
||||
"responseFilePath": "test/response.json.gz",
|
||||
}
|
||||
def test_processing_of_several_messages(self, queue_manager, input_message, stop_message):
|
||||
queue_manager.purge_queues()
|
||||
|
||||
queue_manager = QueueManagerV2()
|
||||
# queue_manager_old = QueueManager(get_config())
|
||||
for _ in range(2):
|
||||
queue_manager.publish_message_to_input_queue(input_message)
|
||||
|
||||
queue_manager.publish_message(message)
|
||||
queue_manager.publish_message(message)
|
||||
queue_manager.publish_message(message)
|
||||
logger.info("Published message")
|
||||
queue_manager.publish_message_to_input_queue(stop_message)
|
||||
|
||||
# consume = lambda: queue_manager.start_consuming(callback)
|
||||
consume = lambda: queue_manager.start_consuming(callback)
|
||||
p = Process(target=consume)
|
||||
p.start()
|
||||
callback = make_callback(1)
|
||||
queue_manager.start_consuming(callback)
|
||||
|
||||
wait_time = 20
|
||||
# logger.info(f"Waiting {wait_time} seconds for the consumer to process the message...")
|
||||
sleep(wait_time)
|
||||
for _ in range(2):
|
||||
response = queue_manager.get_message_from_output_queue()
|
||||
assert response is not None
|
||||
assert response[2] == b'{"status": "success"}'
|
||||
print(response)
|
||||
|
||||
p.kill()
|
||||
def test_all_headers_beginning_with_x_are_forwarded(self, queue_manager, input_message, stop_message):
|
||||
queue_manager.purge_queues()
|
||||
|
||||
response = queue_manager.get_message()
|
||||
properties = pika.BasicProperties(
|
||||
headers={
|
||||
"X-TENANT-ID": "redaction",
|
||||
"X-OTHER-HEADER": "other-header-value",
|
||||
"x-tenant_id": "tenant-id-value",
|
||||
"x_should_not_be_forwarded": "should-not-be-forwarded-value",
|
||||
}
|
||||
)
|
||||
|
||||
logger.info(f"Response: {response}")
|
||||
queue_manager.publish_message_to_input_queue(input_message, properties=properties)
|
||||
queue_manager.publish_message_to_input_queue(stop_message)
|
||||
|
||||
callback = make_callback(0.2)
|
||||
queue_manager.start_consuming(callback)
|
||||
|
||||
response = queue_manager.get_message_from_output_queue()
|
||||
print(response)
|
||||
|
||||
assert response[2] == b'{"status": "success"}'
|
||||
|
||||
assert response[1].headers["X-TENANT-ID"] == "redaction"
|
||||
assert response[1].headers["X-OTHER-HEADER"] == "other-header-value"
|
||||
assert response[1].headers["x-tenant_id"] == "tenant-id-value"
|
||||
|
||||
assert "x_should_not_be_forwarded" not in response[1].headers
|
||||
|
||||
def test_message_processing_does_not_block_heartbeat(self, queue_manager, input_message, stop_message):
|
||||
queue_manager.purge_queues()
|
||||
|
||||
queue_manager.publish_message_to_input_queue(input_message)
|
||||
queue_manager.publish_message_to_input_queue(stop_message)
|
||||
|
||||
callback = make_callback(15)
|
||||
queue_manager.start_consuming(callback)
|
||||
|
||||
response = queue_manager.get_message_from_output_queue()
|
||||
|
||||
assert response[2] == b'{"status": "success"}'
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user