From eed59125168d4117b3c3fb1078dacf6cabace4f3 Mon Sep 17 00:00:00 2001 From: Viktor Seifert Date: Mon, 1 Aug 2022 16:19:13 +0200 Subject: [PATCH] RED-4653: Implemented a startup probe for k8s --- pyinfra/k8s_probes/__init__.py | 0 pyinfra/k8s_probes/startup.py | 36 ++++++++++++++++++++++++++++++++++ pyinfra/queue/queue_manager.py | 23 +++++++++++++++++----- 3 files changed, 54 insertions(+), 5 deletions(-) create mode 100644 pyinfra/k8s_probes/__init__.py create mode 100644 pyinfra/k8s_probes/startup.py diff --git a/pyinfra/k8s_probes/__init__.py b/pyinfra/k8s_probes/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/pyinfra/k8s_probes/startup.py b/pyinfra/k8s_probes/startup.py new file mode 100644 index 0000000..7a10a24 --- /dev/null +++ b/pyinfra/k8s_probes/startup.py @@ -0,0 +1,36 @@ +import logging +from pathlib import Path +import sys + +from pyinfra.queue.queue_manager import token_file_name + + +def check_token_file(): + """ + Checks if the token file of the QueueManager exists and is not empty, i.e. the queue manager has been started. + + NOTE: This function suppresses all Exception's. + + Returns True if the queue manager has been started, False otherwise + """ + + try: + token_file_path = Path(token_file_name()) + + if token_file_path.exists(): + with token_file_path.open(mode="r", encoding="utf8") as token_file: + contents = token_file.read().strip() + + return contents != "" + # We're intentionally do not handle exception here, since we're only using this in a short script. + # Take care to expand this if the intended use changes + except Exception: + logging.getLogger(__file__).info("Caught exception when reading from token file", exc_info=True) + return False + + +if __name__ == '__main__': + if check_token_file(): + sys.exit(0) + else: + sys.exit(1) diff --git a/pyinfra/queue/queue_manager.py b/pyinfra/queue/queue_manager.py index c599fff..3d46b09 100644 --- a/pyinfra/queue/queue_manager.py +++ b/pyinfra/queue/queue_manager.py @@ -3,7 +3,7 @@ import json import logging import signal from typing import Callable -from os import environ +from pathlib import Path import pika import pika.exceptions @@ -30,12 +30,17 @@ def _get_n_previous_attempts(props): return 0 if props.headers is None else props.headers.get("x-retry-count", 0) +def token_file_name(): + token_file_path = Path(__file__).parent / "consumer_token.txt" + return token_file_path + + class QueueManager(object): def __init__(self, config: Config): self.logger = logging.getLogger("queue_manager") self.logger.setLevel(config.logging_level_root) - self._consumer_token = None + self._set_consumer_token(None) self._connection_params = get_connection_params(config) @@ -47,6 +52,14 @@ class QueueManager(object): signal.signal(signal.SIGTERM, self._handle_stop_signal) signal.signal(signal.SIGINT, self._handle_stop_signal) + def _set_consumer_token(self, token_value): + self._consumer_token = token_value + token_file_path = token_file_name() + + with token_file_path.open(mode="w", encoding="utf8") as token_file: + text = token_value is not None if token_value else "" + token_file.write(text) + def _open_channel(self): self._connection = pika.BlockingConnection(parameters=self._connection_params) self._channel = self._connection.channel() @@ -64,13 +77,13 @@ class QueueManager(object): def start_consuming(self, process_message_callback: Callable): callback = self._create_queue_callback(process_message_callback) - self._consumer_token = None + self._set_consumer_token(None) self.logger.info("Consuming from queue") try: self._open_channel() - self._consumer_token = self._channel.basic_consume(self._input_queue, callback) + self._set_consumer_token(self._channel.basic_consume(self._input_queue, callback)) self.logger.info(f"Registered with consumer-tag: {self._consumer_token}") self._channel.start_consuming() except Exception: @@ -86,7 +99,7 @@ class QueueManager(object): if self._consumer_token and self._connection: self.logger.info(f"Cancelling subscription for consumer-tag: {self._consumer_token}") self._channel.stop_consuming(self._consumer_token) - self._consumer_token = None + self._set_consumer_token(None) def _handle_stop_signal(self, signal_number, _stack_frame, *args, **kwargs): self.logger.info(f"Received signal {signal_number}")