From adfbd650e61d7b74fcc9c9d1deb6f0967fe9ff7b Mon Sep 17 00:00:00 2001 From: Julius Unverfehrt Date: Tue, 23 Jan 2024 08:51:44 +0100 Subject: [PATCH] Add config tests, add type validation to config loading --- pyinfra/config/loader.py | 25 ++++++++++-------- pyinfra/config/validators.py | 46 +++++++++++++++++----------------- scripts/send_request.py | 5 ++-- scripts/start_pyinfra.py | 4 +-- tests/conftest.py | 4 +-- tests/unit_test/config_test.py | 36 ++++++++++++++++++++++++++ 6 files changed, 79 insertions(+), 41 deletions(-) create mode 100644 tests/unit_test/config_test.py diff --git a/pyinfra/config/loader.py b/pyinfra/config/loader.py index 56658f9..1e10c44 100644 --- a/pyinfra/config/loader.py +++ b/pyinfra/config/loader.py @@ -2,21 +2,19 @@ import os from pathlib import Path from typing import Union -from dynaconf import Dynaconf, ValidationError +from dynaconf import Dynaconf, ValidationError, Validator from funcy import lflatten from kn_utils.logging import logger -def load_settings(settings_path: Union[str, Path] = None): +def load_settings(settings_path: Union[str, Path] = None, validators: list[Validator] = None): + settings_path = Path(settings_path) if settings_path else None + validators = validators or get_all_validators() if not settings_path: - repo_root_path = Path(__file__).resolve().parents[2] - settings_path = repo_root_path / "config/" - logger.info(f"No settings path provided, using relative settings path: {settings_path}") - - settings_path = Path(settings_path) - - if os.path.isdir(settings_path): + logger.info("No settings path specified, only loading .env end ENVs.") + settings_files = [] + elif os.path.isdir(settings_path): logger.info(f"Settings path is a directory, loading all .toml files in the directory: {settings_path}") settings_files = list(settings_path.glob("*.toml")) else: @@ -29,15 +27,20 @@ def load_settings(settings_path: Union[str, Path] = None): settings_files=settings_files, ) - validate_settings(settings, get_all_validators()) + validate_settings(settings, validators) return settings +pyinfra_config_path = Path(__file__).resolve().parents[2] / "config/" + + def get_all_validators(): import pyinfra.config.validators - return lflatten(validator for validator in pyinfra.config.validators.__dict__.values() if isinstance(validator, list)) + return lflatten( + validator for validator in pyinfra.config.validators.__dict__.values() if isinstance(validator, list) + ) def validate_settings(settings: Dynaconf, validators): diff --git a/pyinfra/config/validators.py b/pyinfra/config/validators.py index bdc6204..e0ae2af 100644 --- a/pyinfra/config/validators.py +++ b/pyinfra/config/validators.py @@ -1,46 +1,46 @@ from dynaconf import Validator queue_manager_validators = [ - Validator("rabbitmq.host", must_exist=True), - Validator("rabbitmq.port", must_exist=True), - Validator("rabbitmq.username", must_exist=True), - Validator("rabbitmq.password", must_exist=True), - Validator("rabbitmq.heartbeat", must_exist=True), - Validator("rabbitmq.connection_sleep", must_exist=True), - Validator("rabbitmq.input_queue", must_exist=True), - Validator("rabbitmq.output_queue", must_exist=True), - Validator("rabbitmq.dead_letter_queue", must_exist=True), + Validator("rabbitmq.host", must_exist=True, is_type_of=str), + Validator("rabbitmq.port", must_exist=True, is_type_of=int), + Validator("rabbitmq.username", must_exist=True, is_type_of=str), + Validator("rabbitmq.password", must_exist=True, is_type_of=str), + Validator("rabbitmq.heartbeat", must_exist=True, is_type_of=int), + Validator("rabbitmq.connection_sleep", must_exist=True, is_type_of=int), + Validator("rabbitmq.input_queue", must_exist=True, is_type_of=str), + Validator("rabbitmq.output_queue", must_exist=True, is_type_of=str), + Validator("rabbitmq.dead_letter_queue", must_exist=True, is_type_of=str), ] azure_storage_validators = [ - Validator("storage.azure.connection_string", must_exist=True), - Validator("storage.azure.container", must_exist=True), + Validator("storage.azure.connection_string", must_exist=True, is_type_of=str), + Validator("storage.azure.container", must_exist=True, is_type_of=str), ] s3_storage_validators = [ - Validator("storage.s3.endpoint", must_exist=True), - Validator("storage.s3.key", must_exist=True), - Validator("storage.s3.secret", must_exist=True), - Validator("storage.s3.region", must_exist=True), - Validator("storage.s3.bucket", must_exist=True), + Validator("storage.s3.endpoint", must_exist=True, is_type_of=str), + Validator("storage.s3.key", must_exist=True, is_type_of=str), + Validator("storage.s3.secret", must_exist=True, is_type_of=str), + Validator("storage.s3.region", must_exist=True, is_type_of=str), + Validator("storage.s3.bucket", must_exist=True, is_type_of=str), ] storage_validators = [ - Validator("storage.backend", must_exist=True), + Validator("storage.backend", must_exist=True, is_type_of=str), ] multi_tenant_storage_validators = [ - Validator("storage.tenant_server.endpoint", must_exist=True), - Validator("storage.tenant_server.public_key", must_exist=True), + Validator("storage.tenant_server.endpoint", must_exist=True, is_type_of=str), + Validator("storage.tenant_server.public_key", must_exist=True, is_type_of=str), ] prometheus_validators = [ - Validator("metrics.prometheus.prefix", must_exist=True), - Validator("metrics.prometheus.enabled", must_exist=True), + Validator("metrics.prometheus.prefix", must_exist=True, is_type_of=str), + Validator("metrics.prometheus.enabled", must_exist=True, is_type_of=bool), ] webserver_validators = [ - Validator("webserver.host", must_exist=True), - Validator("webserver.port", must_exist=True), + Validator("webserver.host", must_exist=True, is_type_of=str), + Validator("webserver.port", must_exist=True, is_type_of=int), ] diff --git a/scripts/send_request.py b/scripts/send_request.py index 79e6a4c..5e464fb 100644 --- a/scripts/send_request.py +++ b/scripts/send_request.py @@ -2,14 +2,13 @@ import gzip import json from operator import itemgetter -import pika from kn_utils.logging import logger -from pyinfra.config.loader import load_settings +from pyinfra.config.loader import load_settings, pyinfra_config_path from pyinfra.queue.manager import QueueManager from pyinfra.storage.storages.s3 import get_s3_storage_from_settings -settings = load_settings() +settings = load_settings(pyinfra_config_path) def upload_json_and_make_message_body(): diff --git a/scripts/start_pyinfra.py b/scripts/start_pyinfra.py index 29874ea..013367b 100644 --- a/scripts/start_pyinfra.py +++ b/scripts/start_pyinfra.py @@ -2,7 +2,7 @@ import argparse import time from pathlib import Path -from pyinfra.config.loader import load_settings +from pyinfra.config.loader import load_settings, pyinfra_config_path from pyinfra.examples import start_queue_consumer_with_prometheus_and_health_endpoints @@ -12,7 +12,7 @@ def parse_args(): "--settings_path", "-s", type=Path, - default=None, + default=pyinfra_config_path, help="Path to settings file or folder. Must be a .toml file or a folder containing .toml files.", ) return parser.parse_args() diff --git a/tests/conftest.py b/tests/conftest.py index 81750c8..4939c83 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,12 +1,12 @@ import pytest -from pyinfra.config.loader import load_settings +from pyinfra.config.loader import load_settings, pyinfra_config_path from pyinfra.storage.connection import get_storage_from_settings @pytest.fixture(scope="session") def settings(): - return load_settings() + return load_settings(pyinfra_config_path) @pytest.fixture(scope="class") diff --git a/tests/unit_test/config_test.py b/tests/unit_test/config_test.py new file mode 100644 index 0000000..363dbac --- /dev/null +++ b/tests/unit_test/config_test.py @@ -0,0 +1,36 @@ +import os + +import pytest +from dynaconf import Validator + +from pyinfra.config.loader import load_settings +from pyinfra.config.validators import webserver_validators + + +@pytest.fixture +def test_validators(): + return [ + Validator("test.value.int", must_exist=True, is_type_of=int), + Validator("test.value.str", must_exist=True, is_type_of=str), + ] + + +class TestConfig: + def test_config_validation(self): + os.environ["WEBSERVER__HOST"] = "localhost" + os.environ["WEBSERVER__PORT"] = "8080" + + validators = webserver_validators + + test_settings = load_settings(validators=validators) + + assert test_settings.webserver.host == "localhost" + + def test_env_into_correct_type_conversion(self, test_validators): + os.environ["TEST__VALUE__INT"] = "1" + os.environ["TEST__VALUE__STR"] = "test" + + test_settings = load_settings(validators=test_validators) + + assert test_settings.test.value.int == 1 + assert test_settings.test.value.str == "test"