From 8867da35576a4cbf286540626413dddde303a16a Mon Sep 17 00:00:00 2001 From: Viktor Seifert Date: Mon, 1 Aug 2022 17:00:44 +0200 Subject: [PATCH] RED-4653: Added value to config to prevent writing the token as a default since that is only useful in a container --- pyinfra/config.py | 3 +++ pyinfra/queue/queue_manager.py | 14 +++++++++----- 2 files changed, 12 insertions(+), 5 deletions(-) diff --git a/pyinfra/config.py b/pyinfra/config.py index ee1efa7..89fbce3 100644 --- a/pyinfra/config.py +++ b/pyinfra/config.py @@ -54,6 +54,9 @@ class Config(object): "STORAGE_AZURECONNECTIONSTRING", "DefaultEndpointsProtocol=..." ) + # Value to see if we should write a consumer token to a file + self.write_consumer_token = read_from_environment("WRITE_CONSUMER_TOKEN", "False") + def get_config() -> Config: return Config() diff --git a/pyinfra/queue/queue_manager.py b/pyinfra/queue/queue_manager.py index 3d46b09..3589998 100644 --- a/pyinfra/queue/queue_manager.py +++ b/pyinfra/queue/queue_manager.py @@ -31,7 +31,7 @@ def _get_n_previous_attempts(props): def token_file_name(): - token_file_path = Path(__file__).parent / "consumer_token.txt" + token_file_path = Path("/var/run") / "consumer_token.txt" return token_file_path @@ -40,6 +40,8 @@ class QueueManager(object): self.logger = logging.getLogger("queue_manager") self.logger.setLevel(config.logging_level_root) + self._write_token = config.write_consumer_token == "True" + self._set_consumer_token(None) self._connection_params = get_connection_params(config) @@ -54,11 +56,13 @@ class QueueManager(object): def _set_consumer_token(self, token_value): self._consumer_token = token_value - token_file_path = token_file_name() - with token_file_path.open(mode="w", encoding="utf8") as token_file: - text = token_value is not None if token_value else "" - token_file.write(text) + if self._write_token: + token_file_path = token_file_name() + + with token_file_path.open(mode="w", encoding="utf8") as token_file: + text = token_value is not None if token_value else "" + token_file.write(text) def _open_channel(self): self._connection = pika.BlockingConnection(parameters=self._connection_params)