From ec5ad09fa8d07ad9f9b51531d6971e9cf6046a5a Mon Sep 17 00:00:00 2001
From: Julius Unverfehrt
Date: Thu, 18 Jan 2024 11:22:27 +0100
Subject: [PATCH] refactor: multi-tenant storage connection

---
 pyinfra/payload_processing/monitor.py         |  57 --------
 pyinfra/storage/connection.py                 | 129 +++++++++++++++++
 pyinfra/storage/storage.py                    |  51 -------
 pyinfra/storage/storage_info.py               | 125 ----------------
 pyinfra/storage/storage_provider.py           |  55 -------
 pyinfra/storage/storages/azure.py             |  58 ++++----
 pyinfra/storage/storages/interface.py         |  21 +--
 pyinfra/storage/storages/mock.py              |  31 ++--
 pyinfra/storage/storages/s3.py                |  54 +++----
 pyinfra/utils/config_validation.py            |   8 ++
 pyinfra/webserver.py                          |  12 +-
 .../prometheus_monitoring_test.py             |   4 +-
 .../tests_with_docker_compose/storage_test.py | 135 ++++++++++++------
 13 files changed, 334 insertions(+), 406 deletions(-)
 delete mode 100644 pyinfra/payload_processing/monitor.py
 create mode 100644 pyinfra/storage/connection.py
 delete mode 100644 pyinfra/storage/storage.py
 delete mode 100644 pyinfra/storage/storage_info.py
 delete mode 100644 pyinfra/storage/storage_provider.py

diff --git a/pyinfra/payload_processing/monitor.py b/pyinfra/payload_processing/monitor.py
deleted file mode 100644
index 5ea2d94..0000000
--- a/pyinfra/payload_processing/monitor.py
+++ /dev/null
@@ -1,57 +0,0 @@
-from funcy import identity
-from operator import attrgetter
-from prometheus_client import Summary, start_http_server, CollectorRegistry
-from time import time
-from typing import Callable, Any, Sized
-
-from pyinfra.config import Config
-
-
-class PrometheusMonitor:
-    def __init__(self, prefix: str, host: str, port: int):
-        """Register the monitoring metrics and start a webserver where they can be scraped at the endpoint
-        http://{host}:{port}/prometheus
-
-        Args:
-            prefix: should per convention consist of {product_name}_{service_name}_{parameter_to_monitor}
-                parameter_to_monitor is defined by the result of the processing service.
-        """
-        self.registry = CollectorRegistry()
-
-        self.entity_processing_time_sum = Summary(
-            f"{prefix}_processing_time", "Summed up average processing time per entity observed", registry=self.registry
-        )
-
-        start_http_server(port, host, self.registry)
-
-    def __call__(self, process_fn: Callable) -> Callable:
-        """Monitor the runtime of a function and update the registered metric with the average runtime per resulting
-        element.
-        """
-        return self._add_result_monitoring(process_fn)
-
-    def _add_result_monitoring(self, process_fn: Callable):
-        def inner(data: Any, **kwargs):
-            start = time()
-
-            result: Sized = process_fn(data, **kwargs)
-
-            runtime = time() - start
-
-            if not result:
-                return result
-
-            processing_time_per_entity = runtime / len(result)
-
-            self.entity_processing_time_sum.observe(processing_time_per_entity)
-
-            return result
-
-        return inner
-
-
-def get_monitor_from_config(config: Config) -> Callable:
-    if config.monitoring_enabled:
-        return PrometheusMonitor(*attrgetter("prometheus_metric_prefix", "prometheus_host", "prometheus_port")(config))
-    else:
-        return identity
diff --git a/pyinfra/storage/connection.py b/pyinfra/storage/connection.py
new file mode 100644
index 0000000..9c06349
--- /dev/null
+++ b/pyinfra/storage/connection.py
@@ -0,0 +1,129 @@
+from functools import lru_cache, partial
+from typing import Callable, Optional
+
+import requests
+from dynaconf import Dynaconf
+from funcy import compose
+from kn_utils.logging import logger
+
+from pyinfra.storage.storages.azure import get_azure_storage_from_settings
+from pyinfra.storage.storages.interface import Storage
+from pyinfra.storage.storages.s3 import get_s3_storage_from_settings
+from pyinfra.utils.cipher import decrypt
+from pyinfra.utils.compressing import get_decompressor, get_compressor
+from pyinfra.utils.config_validation import validate_settings, storage_validators, multi_tenant_storage_validators
+from pyinfra.utils.encoding import get_decoder, get_encoder
+
+
+def get_storage(settings: Dynaconf, tenant_id: Optional[str] = None) -> Storage:
+    """Get a storage connection based on the settings.
+    If tenant_id is provided, the connection information is fetched from the tenant server instead.
+    Tenant connections are cached (FIFO, bounded by settings.storage.cache_size).
+
+    In the future, when the default storage from config is no longer needed (only multi-tenant storage will be used),
+    get_storage_from_tenant_id can replace this function directly.
+    """
+    if tenant_id:
+        logger.info(f"Using tenant storage for {tenant_id}.")
+        return get_storage_from_tenant_id(tenant_id, settings)
+    else:
+        logger.info("Using default storage.")
+        return get_storage_from_settings(settings)
+
+
+def get_storage_from_settings(settings: Dynaconf) -> Storage:
+    validate_settings(settings, storage_validators)
+
+    return storage_dispatcher[settings.storage.backend](settings)
+
+
+# Module-level cache: an lru_cache created inside get_storage_from_tenant_id would be
+# rebuilt (and therefore empty) on every call, hitting the tenant server each time.
+_tenant_storage_cache: dict = {}
+
+
+def get_storage_from_tenant_id(tenant_id: str, settings: Dynaconf) -> Storage:
+    validate_settings(settings, multi_tenant_storage_validators)
+
+    cache_key = (tenant_id, settings.storage.tenant_server.endpoint)
+    if cache_key not in _tenant_storage_cache:
+        if len(_tenant_storage_cache) >= settings.storage.cache_size:
+            _tenant_storage_cache.pop(next(iter(_tenant_storage_cache)))  # evict the oldest entry (FIFO)
+        _tenant_storage_cache[cache_key] = _connect_tenant_storage(tenant_id, settings)
+    return _tenant_storage_cache[cache_key]
+
+
+def _connect_tenant_storage(tenant_id: str, settings: Dynaconf) -> Storage:
+    endpoint = settings.storage.tenant_server.endpoint
+    public_key = settings.storage.tenant_server.public_key
+    response = requests.get(f"{endpoint}/{tenant_id}").json()
+
+    maybe_azure = response.get("azureStorageConnection")
+    maybe_s3 = response.get("s3StorageConnection")
+    assert (maybe_azure or maybe_s3) and not (maybe_azure and maybe_s3), "Exactly one storage backend must be configured."
+
+    if maybe_azure:
+        backend = "azure"
+        storage_settings = {
+            "storage": {
+                "azure": {
+                    "connection_string": decrypt(public_key, maybe_azure["connectionString"]),
+                    "container": maybe_azure["containerName"],
+                },
+            }
+        }
+    else:
+        backend = "s3"
+        storage_settings = {
+            "storage": {
+                "s3": {
+                    "endpoint": maybe_s3["endpoint"],
+                    "key": maybe_s3["key"],
+                    "secret": decrypt(public_key, maybe_s3["secret"]),
+                    "region": maybe_s3["region"],
+                    "bucket": maybe_s3["bucketName"],
+                },
+            }
+        }
+
+    # Overlay the tenant-specific storage settings on top of the service settings; the
+    # overlay must be applied last so the tenant values are not shadowed by the defaults.
+    tenant_settings = Dynaconf()
+    tenant_settings.update(settings)
+    tenant_settings.update(storage_settings)
+
+    return storage_dispatcher[backend](tenant_settings)
+
+
+storage_dispatcher = {
+    "azure": get_azure_storage_from_settings,
+    "s3": get_s3_storage_from_settings,
+}
+
+
+@lru_cache(maxsize=10)
+def make_downloader(storage: Storage, file_type: str, compression_type: str) -> Callable:
+    verify = partial(verify_existence, storage)
+    download = storage.get_object
+    decompress = get_decompressor(compression_type)
+    decode = get_decoder(file_type)
+
+    return compose(decode, decompress, download, verify)
+
+
+@lru_cache(maxsize=10)
+def make_uploader(storage: Storage, file_type: str, compression_type: str) -> Callable:
+    upload = storage.put_object
+    compress = get_compressor(compression_type)
+    encode = get_encoder(file_type)
+
+    def inner(file_name, file_bytes):
+        upload(file_name, compose(compress, encode)(file_bytes))
+
+    return inner
+
+
+def verify_existence(storage: Storage, file_name: str) -> str:
+    if not storage.exists(file_name):
+        raise FileNotFoundError(f"{file_name=} not found on storage in {storage.bucket=}.")
+    return file_name
diff --git a/pyinfra/storage/storage.py b/pyinfra/storage/storage.py
deleted file mode 100644
index c452f0d..0000000
--- a/pyinfra/storage/storage.py
+++ /dev/null
@@ -1,51 +0,0 @@
-from functools import lru_cache, partial
-from typing import Callable
-
-from dynaconf import Dynaconf
-from funcy import compose
-
-from pyinfra.storage.storages.interface import Storage
-from pyinfra.storage.storages.s3 import get_s3_storage_from_settings
-from pyinfra.utils.compressing import get_decompressor, get_compressor
-from pyinfra.utils.config_validation import validate_settings, storage_validators
-from pyinfra.utils.encoding import get_decoder, get_encoder
-
-
-def get_storage_from_settings(settings: Dynaconf) -> Storage:
-    validate_settings(settings, storage_validators)
-
-    return storage_dispatcher[settings.storage.backend](settings)
-
-
-storage_dispatcher = {
-    "azure": get_s3_storage_from_settings,
-    "s3": get_s3_storage_from_settings,
-}
-
-
-@lru_cache(maxsize=10)
-def make_downloader(storage: Storage, bucket: str, file_type: str, compression_type: str) -> Callable:
-    verify = partial(verify_existence, storage, bucket)
-    download = partial(storage.get_object, bucket)
-    decompress = get_decompressor(compression_type)
-    decode = get_decoder(file_type)
-
-    return compose(decode, decompress, download, verify)
-
-
-@lru_cache(maxsize=10)
-def make_uploader(storage: Storage, bucket: str, file_type: str, compression_type: str) -> Callable:
-    upload = partial(storage.put_object, bucket)
-    compress = get_compressor(compression_type)
-    encode = 
get_encoder(file_type) - - def inner(file_name, file_bytes): - upload(file_name, compose(compress, encode)(file_bytes)) - - return inner - - -def verify_existence(storage: Storage, bucket: str, file_name: str) -> str: - if not storage.exists(bucket, file_name): - raise FileNotFoundError(f"{file_name=} name not found on storage in {bucket=}.") - return file_name diff --git a/pyinfra/storage/storage_info.py b/pyinfra/storage/storage_info.py deleted file mode 100644 index aefa15b..0000000 --- a/pyinfra/storage/storage_info.py +++ /dev/null @@ -1,125 +0,0 @@ -from dataclasses import dataclass - -import requests -from azure.storage.blob import BlobServiceClient -from minio import Minio - -from pyinfra.config import Config -from pyinfra.exception import UnknownStorageBackend -from pyinfra.storage.storages.azure import AzureStorage -from pyinfra.storage.storages.interface import Storage -from pyinfra.storage.storages.s3 import S3Storage -from pyinfra.utils.cipher import decrypt -from pyinfra.utils.url_parsing import validate_and_parse_s3_endpoint - - -@dataclass(frozen=True) -class StorageInfo: - bucket_name: str - - -@dataclass(frozen=True) -class AzureStorageInfo(StorageInfo): - connection_string: str - - def __hash__(self): - return hash(self.connection_string) - - def __eq__(self, other): - if not isinstance(other, AzureStorageInfo): - return False - return self.connection_string == other.connection_string - - -@dataclass(frozen=True) -class S3StorageInfo(StorageInfo): - secure: bool - endpoint: str - access_key: str - secret_key: str - region: str - - def __hash__(self): - return hash((self.secure, self.endpoint, self.access_key, self.secret_key, self.region)) - - def __eq__(self, other): - if not isinstance(other, S3StorageInfo): - return False - return ( - self.secure == other.secure - and self.endpoint == other.endpoint - and self.access_key == other.access_key - and self.secret_key == other.secret_key - and self.region == other.region - ) - - -def get_storage_from_storage_info(storage_info: StorageInfo) -> Storage: - if isinstance(storage_info, AzureStorageInfo): - return AzureStorage(BlobServiceClient.from_connection_string(conn_str=storage_info.connection_string)) - elif isinstance(storage_info, S3StorageInfo): - return S3Storage( - Minio( - secure=storage_info.secure, - endpoint=storage_info.endpoint, - access_key=storage_info.access_key, - secret_key=storage_info.secret_key, - region=storage_info.region, - ) - ) - else: - raise UnknownStorageBackend() - - -def get_storage_info_from_endpoint(public_key: str, endpoint: str, x_tenant_id: str) -> StorageInfo: - resp = requests.get(f"{endpoint}/{x_tenant_id}").json() - - maybe_azure = resp.get("azureStorageConnection") - maybe_s3 = resp.get("s3StorageConnection") - assert not (maybe_azure and maybe_s3) - - if maybe_azure: - connection_string = decrypt(public_key, maybe_azure["connectionString"]) - storage_info = AzureStorageInfo( - connection_string=connection_string, - bucket_name=maybe_azure["containerName"], - ) - elif maybe_s3: - secure, endpoint = validate_and_parse_s3_endpoint(maybe_s3["endpoint"]) - secret = decrypt(public_key, maybe_s3["secret"]) - - storage_info = S3StorageInfo( - secure=secure, - endpoint=endpoint, - access_key=maybe_s3["key"], - secret_key=secret, - region=maybe_s3["region"], - bucket_name=maybe_s3["bucketName"], - ) - else: - raise UnknownStorageBackend() - - return storage_info - - -def get_storage_info_from_config(config: Config) -> StorageInfo: - if config.storage_backend == "s3": - storage_info = 
S3StorageInfo( - secure=config.storage_secure_connection, - endpoint=config.storage_endpoint, - access_key=config.storage_key, - secret_key=config.storage_secret, - region=config.storage_region, - bucket_name=config.storage_bucket, - ) - - elif config.storage_backend == "azure": - storage_info = AzureStorageInfo( - connection_string=config.storage_azureconnectionstring, - bucket_name=config.storage_bucket, - ) - - else: - raise UnknownStorageBackend(f"Unknown storage backend '{config.storage_backend}'.") - - return storage_info diff --git a/pyinfra/storage/storage_provider.py b/pyinfra/storage/storage_provider.py deleted file mode 100644 index 345a096..0000000 --- a/pyinfra/storage/storage_provider.py +++ /dev/null @@ -1,55 +0,0 @@ -from dataclasses import asdict -from functools import partial, lru_cache -from kn_utils.logging import logger -from typing import Tuple - -from pyinfra.config import Config -from pyinfra.storage.storage_info import ( - get_storage_info_from_config, - get_storage_info_from_endpoint, - StorageInfo, - get_storage_from_storage_info, -) -from pyinfra.storage.storages.interface import Storage - - -class StorageProvider: - def __init__(self, config: Config): - self.config = config - self.default_storage_info: StorageInfo = get_storage_info_from_config(config) - - self.get_storage_info_from_tenant_id = partial( - get_storage_info_from_endpoint, - config.tenant_decryption_public_key, - config.tenant_endpoint, - ) - - def __call__(self, *args, **kwargs): - return self._connect(*args, **kwargs) - - @lru_cache(maxsize=32) - def _connect(self, x_tenant_id=None) -> Tuple[Storage, StorageInfo]: - storage_info = self._get_storage_info(x_tenant_id) - storage_connection = get_storage_from_storage_info(storage_info) - return storage_connection, storage_info - - def _get_storage_info(self, x_tenant_id=None): - if x_tenant_id: - storage_info = self.get_storage_info_from_tenant_id(x_tenant_id) - logger.debug(f"Received {storage_info.__class__.__name__} for {x_tenant_id} from endpoint.") - logger.trace(f"{asdict(storage_info)}") - else: - storage_info = self.default_storage_info - logger.debug(f"Using local default {storage_info.__class__.__name__} for {x_tenant_id}.") - logger.trace(f"{asdict(storage_info)}") - - return storage_info - - -class StorageProviderMock(StorageProvider): - def __init__(self, storage, storage_info): - self.storage = storage - self.storage_info = storage_info - - def __call__(self, *args, **kwargs): - return self.storage, self.storage_info diff --git a/pyinfra/storage/storages/azure.py b/pyinfra/storage/storages/azure.py index 4c7467c..1e56630 100644 --- a/pyinfra/storage/storages/azure.py +++ b/pyinfra/storage/storages/azure.py @@ -15,47 +15,52 @@ logging.getLogger("urllib3").setLevel(logging.WARNING) class AzureStorage(Storage): - def __init__(self, client: BlobServiceClient): + def __init__(self, client: BlobServiceClient, bucket: str): self._client: BlobServiceClient = client + self._bucket = bucket - def has_bucket(self, bucket_name): - container_client = self._client.get_container_client(bucket_name) + @property + def bucket(self): + return self._bucket + + def has_bucket(self): + container_client = self._client.get_container_client(self.bucket) return container_client.exists() - def make_bucket(self, bucket_name): - container_client = self._client.get_container_client(bucket_name) - container_client if container_client.exists() else self._client.create_container(bucket_name) + def make_bucket(self): + container_client = 
self._client.get_container_client(self.bucket)
+        if not container_client.exists():
+            self._client.create_container(self.bucket)
 
-    def __provide_container_client(self, bucket_name) -> ContainerClient:
-        self.make_bucket(bucket_name)
-        container_client = self._client.get_container_client(bucket_name)
+    def __provide_container_client(self) -> ContainerClient:
+        self.make_bucket()
+        container_client = self._client.get_container_client(self.bucket)
         return container_client
 
-    def put_object(self, bucket_name, object_name, data):
+    def put_object(self, object_name, data):
         logger.debug(f"Uploading '{object_name}'...")
-        container_client = self.__provide_container_client(bucket_name)
+        container_client = self.__provide_container_client()
         blob_client = container_client.get_blob_client(object_name)
         blob_client.upload_blob(data, overwrite=True)
 
-    def exists(self, bucket_name, object_name):
-        container_client = self.__provide_container_client(bucket_name)
+    def exists(self, object_name):
+        container_client = self.__provide_container_client()
         blob_client = container_client.get_blob_client(object_name)
         return blob_client.exists()
 
     @retry(tries=3, delay=5, jitter=(1, 3))
-    def get_object(self, bucket_name, object_name):
+    def get_object(self, object_name):
         logger.debug(f"Downloading '{object_name}'...")
         try:
-            container_client = self.__provide_container_client(bucket_name)
+            container_client = self.__provide_container_client()
             blob_client = container_client.get_blob_client(object_name)
             blob_data = blob_client.download_blob()
             return blob_data.readall()
         except Exception as err:
             raise Exception("Failed getting object from azure client") from err
 
-    def get_all_objects(self, bucket_name):
-        container_client = self.__provide_container_client(bucket_name)
+    def get_all_objects(self):
+        container_client = self.__provide_container_client()
         blobs = container_client.list_blobs()
         for blob in blobs:
             logger.debug(f"Downloading '{blob.name}'...")
@@ -64,18 +69,22 @@ class AzureStorage(Storage):
             data = blob_data.readall()
             yield data
 
-    def clear_bucket(self, bucket_name):
-        logger.debug(f"Clearing Azure container '{bucket_name}'...")
-        container_client = self._client.get_container_client(bucket_name)
+    def clear_bucket(self):
+        logger.debug(f"Clearing Azure container '{self.bucket}'...")
+        container_client = self._client.get_container_client(self.bucket)
         blobs = container_client.list_blobs()
         container_client.delete_blobs(*blobs)
 
-    def get_all_object_names(self, bucket_name):
-        container_client = self.__provide_container_client(bucket_name)
+    def get_all_object_names(self):
+        container_client = self.__provide_container_client()
         blobs = container_client.list_blobs()
-        return zip(repeat(bucket_name), map(attrgetter("name"), blobs))
+        return zip(repeat(self.bucket), map(attrgetter("name"), blobs))
 
 
 def get_azure_storage_from_settings(settings: Dynaconf):
     validate_settings(settings, azure_storage_validators)
-    return AzureStorage(BlobServiceClient.from_connection_string(conn_str=settings.storage.azure.connection_string))
+
+    return AzureStorage(
+        client=BlobServiceClient.from_connection_string(conn_str=settings.storage.azure.connection_string),
+        bucket=settings.storage.azure.container,
+    )
diff --git a/pyinfra/storage/storages/interface.py b/pyinfra/storage/storages/interface.py
index f5530d6..6283f13 100644
--- a/pyinfra/storage/storages/interface.py
+++ b/pyinfra/storage/storages/interface.py
@@ -2,34 +2,39 @@ from abc import ABC, abstractmethod
 
 
 class Storage(ABC):
+    @property
     @abstractmethod
-    def make_bucket(self, bucket_name):
+    def bucket(self):
         raise NotImplementedError
 
     @abstractmethod
-    def has_bucket(self, bucket_name):
+    def make_bucket(self):
         raise NotImplementedError
 
     @abstractmethod
-    def put_object(self, bucket_name, object_name, data):
+    def has_bucket(self):
         raise NotImplementedError
 
     @abstractmethod
-    def exists(self, bucket_name, object_name):
+    def put_object(self, object_name, data):
         raise NotImplementedError
 
     @abstractmethod
-    def get_object(self, bucket_name, object_name):
+    def exists(self, object_name):
         raise NotImplementedError
 
     @abstractmethod
-    def get_all_objects(self, bucket_name):
+    def get_object(self, object_name):
         raise NotImplementedError
 
     @abstractmethod
-    def clear_bucket(self, bucket_name):
+    def get_all_objects(self):
         raise NotImplementedError
 
     @abstractmethod
-    def get_all_object_names(self, bucket_name):
+    def clear_bucket(self):
+        raise NotImplementedError
+
+    @abstractmethod
+    def get_all_object_names(self):
     raise NotImplementedError
diff --git a/pyinfra/storage/storages/mock.py b/pyinfra/storage/storages/mock.py
index b209399..2a77bca 100644
--- a/pyinfra/storage/storages/mock.py
+++ b/pyinfra/storage/storages/mock.py
@@ -5,32 +5,35 @@ class StorageMock(Storage):
     def __init__(self, data: bytes = None, file_name: str = None, bucket: str = None):
         self.data = data
         self.file_name = file_name
-        self.bucket = bucket
+        self._bucket = bucket
 
-    def make_bucket(self, bucket_name):
-        self.bucket = bucket_name
+    @property
+    def bucket(self):
+        return self._bucket
 
-    def has_bucket(self, bucket_name):
-        return self.bucket == bucket_name
+    def make_bucket(self):
+        pass
 
-    def put_object(self, bucket_name, object_name, data):
-        self.bucket = bucket_name
+    def has_bucket(self):
+        return True
+
+    def put_object(self, object_name, data):
         self.file_name = object_name
         self.data = data
 
-    def exists(self, bucket_name, object_name):
-        return self.bucket == bucket_name and self.file_name == object_name
+    def exists(self, object_name):
+        return self.file_name == object_name
 
-    def get_object(self, bucket_name, object_name):
+    def get_object(self, object_name):
         return self.data
 
-    def get_all_objects(self, bucket_name):
+    def get_all_objects(self):
         raise NotImplementedError
 
-    def clear_bucket(self, bucket_name):
-        self.bucket = None
+    def clear_bucket(self):
+        # Clearing a real bucket removes its objects, not the bucket itself, so keep self._bucket.
         self.file_name = None
         self.data = None
 
-    def get_all_object_names(self, bucket_name):
+    def get_all_object_names(self):
         raise NotImplementedError
diff --git a/pyinfra/storage/storages/s3.py b/pyinfra/storage/storages/s3.py
index 57c8ac0..dcc151d 100644
--- a/pyinfra/storage/storages/s3.py
+++ b/pyinfra/storage/storages/s3.py
@@ -13,35 +13,40 @@ from pyinfra.utils.url_parsing import validate_and_parse_s3_endpoint
 
 
 class S3Storage(Storage):
-    def __init__(self, client: Minio):
+    def __init__(self, client: Minio, bucket: str):
         self._client = client
+        self._bucket = bucket
 
-    def make_bucket(self, bucket_name):
-        if not self.has_bucket(bucket_name):
-            self._client.make_bucket(bucket_name)
+    @property
+    def bucket(self):
+        return self._bucket
 
-    def has_bucket(self, bucket_name):
-        return self._client.bucket_exists(bucket_name)
+    def make_bucket(self):
+        if not self.has_bucket():
+            self._client.make_bucket(self.bucket)
 
-    def put_object(self, bucket_name, object_name, data):
+    def has_bucket(self):
+        return self._client.bucket_exists(self.bucket)
+
+    def put_object(self, object_name, data):
         logger.debug(f"Uploading '{object_name}'...")
         data = io.BytesIO(data)
-        self._client.put_object(bucket_name, object_name, 
data, length=data.getbuffer().nbytes) + self._client.put_object(self.bucket, object_name, data, length=data.getbuffer().nbytes) - def exists(self, bucket_name, object_name): + def exists(self, object_name): try: - self._client.stat_object(bucket_name, object_name) + self._client.stat_object(self.bucket, object_name) return True except Exception: return False @retry(tries=3, delay=5, jitter=(1, 3)) - def get_object(self, bucket_name, object_name): + def get_object(self, object_name): logger.debug(f"Downloading '{object_name}'...") response = None try: - response = self._client.get_object(bucket_name, object_name) + response = self._client.get_object(self.bucket, object_name) return response.data except Exception as err: raise Exception("Failed getting object from s3 client") from err @@ -50,20 +55,20 @@ class S3Storage(Storage): response.close() response.release_conn() - def get_all_objects(self, bucket_name): - for obj in self._client.list_objects(bucket_name, recursive=True): + def get_all_objects(self): + for obj in self._client.list_objects(self.bucket, recursive=True): logger.debug(f"Downloading '{obj.object_name}'...") - yield self.get_object(bucket_name, obj.object_name) + yield self.get_object(obj.object_name) - def clear_bucket(self, bucket_name): - logger.debug(f"Clearing S3 bucket '{bucket_name}'...") - objects = self._client.list_objects(bucket_name, recursive=True) + def clear_bucket(self): + logger.debug(f"Clearing S3 bucket '{self.bucket}'...") + objects = self._client.list_objects(self.bucket, recursive=True) for obj in objects: - self._client.remove_object(bucket_name, obj.object_name) + self._client.remove_object(self.bucket, obj.object_name) - def get_all_object_names(self, bucket_name): - objs = self._client.list_objects(bucket_name, recursive=True) - return zip(repeat(bucket_name), map(attrgetter("object_name"), objs)) + def get_all_object_names(self): + objs = self._client.list_objects(self.bucket, recursive=True) + return zip(repeat(self.bucket), map(attrgetter("object_name"), objs)) def get_s3_storage_from_settings(settings: Dynaconf): @@ -72,11 +77,12 @@ def get_s3_storage_from_settings(settings: Dynaconf): secure, endpoint = validate_and_parse_s3_endpoint(settings.storage.s3.endpoint) return S3Storage( - Minio( + client=Minio( secure=secure, endpoint=endpoint, access_key=settings.storage.s3.key, secret_key=settings.storage.s3.secret, region=settings.storage.s3.region, - ) + ), + bucket=settings.storage.s3.bucket, ) diff --git a/pyinfra/utils/config_validation.py b/pyinfra/utils/config_validation.py index 028f5d4..629891d 100644 --- a/pyinfra/utils/config_validation.py +++ b/pyinfra/utils/config_validation.py @@ -15,6 +15,7 @@ queue_manager_validators = [ azure_storage_validators = [ Validator("storage.azure.connection_string", must_exist=True), + Validator("storage.azure.container", must_exist=True), ] s3_storage_validators = [ @@ -22,12 +23,19 @@ s3_storage_validators = [ Validator("storage.s3.key", must_exist=True), Validator("storage.s3.secret", must_exist=True), Validator("storage.s3.region", must_exist=True), + Validator("storage.s3.bucket", must_exist=True), ] storage_validators = [ Validator("storage.backend", must_exist=True), ] +multi_tenant_storage_validators = [ + Validator("storage.tenant_server.endpoint", must_exist=True), + Validator("storage.tenant_server.public_key", must_exist=True), +] + + prometheus_validators = [ Validator("metrics.prometheus.prefix", must_exist=True), Validator("metrics.prometheus.enabled", must_exist=True), diff --git 
a/pyinfra/webserver.py b/pyinfra/webserver.py index acecf24..0d37139 100644 --- a/pyinfra/webserver.py +++ b/pyinfra/webserver.py @@ -8,11 +8,11 @@ from fastapi import FastAPI from pyinfra.utils.config_validation import validate_settings, webserver_validators -def create_webserver_thread(app: FastAPI, settings: Dynaconf) -> threading.Thread: +def create_webserver_thread_from_settings(app: FastAPI, settings: Dynaconf) -> threading.Thread: validate_settings(settings, validators=webserver_validators) - return threading.Thread( - target=lambda: uvicorn.run( - app, port=settings.webserver.port, host=settings.webserver.host, log_level=logging.WARNING - ) - ) + return create_webserver_thread(app=app, port=settings.webserver.port, host=settings.webserver.host) + + +def create_webserver_thread(app: FastAPI, port: int, host: str) -> threading.Thread: + return threading.Thread(target=lambda: uvicorn.run(app, port=port, host=host, log_level=logging.WARNING)) diff --git a/tests/tests_with_docker_compose/prometheus_monitoring_test.py b/tests/tests_with_docker_compose/prometheus_monitoring_test.py index a6d0df1..75da089 100644 --- a/tests/tests_with_docker_compose/prometheus_monitoring_test.py +++ b/tests/tests_with_docker_compose/prometheus_monitoring_test.py @@ -6,14 +6,14 @@ import requests from fastapi import FastAPI from pyinfra.monitor.prometheus import add_prometheus_endpoint, make_prometheus_processing_time_decorator_from_settings -from pyinfra.webserver import create_webserver_thread +from pyinfra.webserver import create_webserver_thread_from_settings @pytest.fixture(scope="class") def app_with_prometheus_endpoint(settings): app = FastAPI() app = add_prometheus_endpoint(app) - thread = create_webserver_thread(app, settings) + thread = create_webserver_thread_from_settings(app, settings) thread.daemon = True thread.start() sleep(1) diff --git a/tests/tests_with_docker_compose/storage_test.py b/tests/tests_with_docker_compose/storage_test.py index 6f0bbd6..1f97468 100644 --- a/tests/tests_with_docker_compose/storage_test.py +++ b/tests/tests_with_docker_compose/storage_test.py @@ -1,64 +1,119 @@ +from time import sleep + import pytest +from fastapi import FastAPI -from pyinfra.storage.storage import get_storage_from_settings +from pyinfra.storage.connection import get_storage_from_settings, get_storage_from_tenant_id +from pyinfra.utils.cipher import encrypt +from pyinfra.webserver import create_webserver_thread -@pytest.fixture(scope="session") -def storage(storage_backend, bucket_name, settings): +@pytest.fixture(scope="class") +def storage(storage_backend, settings): settings.storage.backend = storage_backend storage = get_storage_from_settings(settings) - storage.make_bucket(bucket_name) + storage.make_bucket() yield storage - storage.clear_bucket(bucket_name) + storage.clear_bucket() -@pytest.mark.parametrize("storage_backend", ["azure", "s3"], scope="session") -@pytest.mark.parametrize("bucket_name", ["bucket"], scope="session") +@pytest.fixture(scope="class") +def tenant_server_mock(settings, tenant_server_host, tenant_server_port): + app = FastAPI() + + @app.get("/azure_tenant") + def get_azure_storage_info(): + return { + "azureStorageConnection": { + "connectionString": encrypt( + settings.storage.tenant_server.public_key, settings.storage.azure.connection_string + ), + "containerName": settings.storage.azure.container, + } + } + + @app.get("/s3_tenant") + def get_s3_storage_info(): + return { + "s3StorageConnection": { + "endpoint": settings.storage.s3.endpoint, + "key": 
settings.storage.s3.key, + "secret": encrypt(settings.storage.tenant_server.public_key, settings.storage.s3.secret), + "region": settings.storage.s3.region, + "bucketName": settings.storage.s3.bucket, + } + } + + thread = create_webserver_thread(app, tenant_server_port, tenant_server_host) + thread.daemon = True + thread.start() + sleep(1) + yield + thread.join(timeout=1) + + +@pytest.mark.parametrize("storage_backend", ["azure", "s3"], scope="class") class TestStorage: - def test_clearing_bucket_yields_empty_bucket(self, storage, bucket_name): - storage.clear_bucket(bucket_name) - data_received = storage.get_all_objects(bucket_name) + def test_clearing_bucket_yields_empty_bucket(self, storage): + storage.clear_bucket() + data_received = storage.get_all_objects() assert not {*data_received} - def test_getting_object_put_in_bucket_is_object(self, storage, bucket_name): - storage.clear_bucket(bucket_name) - storage.put_object(bucket_name, "file", b"content") - data_received = storage.get_object(bucket_name, "file") + def test_getting_object_put_in_bucket_is_object(self, storage): + storage.clear_bucket() + storage.put_object("file", b"content") + data_received = storage.get_object("file") assert b"content" == data_received - def test_object_put_in_bucket_exists_on_storage(self, storage, bucket_name): - storage.clear_bucket(bucket_name) - storage.put_object(bucket_name, "file", b"content") - assert storage.exists(bucket_name, "file") + def test_object_put_in_bucket_exists_on_storage(self, storage): + storage.clear_bucket() + storage.put_object("file", b"content") + assert storage.exists("file") - def test_getting_nested_object_put_in_bucket_is_nested_object(self, storage, bucket_name): - storage.clear_bucket(bucket_name) - storage.put_object(bucket_name, "folder/file", b"content") - data_received = storage.get_object(bucket_name, "folder/file") + def test_getting_nested_object_put_in_bucket_is_nested_object(self, storage): + storage.clear_bucket() + storage.put_object("folder/file", b"content") + data_received = storage.get_object("folder/file") assert b"content" == data_received - def test_getting_objects_put_in_bucket_are_objects(self, storage, bucket_name): - storage.clear_bucket(bucket_name) - storage.put_object(bucket_name, "file1", b"content 1") - storage.put_object(bucket_name, "folder/file2", b"content 2") - data_received = storage.get_all_objects(bucket_name) + def test_getting_objects_put_in_bucket_are_objects(self, storage): + storage.clear_bucket() + storage.put_object("file1", b"content 1") + storage.put_object("folder/file2", b"content 2") + data_received = storage.get_all_objects() assert {b"content 1", b"content 2"} == {*data_received} - def test_make_bucket_produces_bucket(self, storage, bucket_name): - storage.clear_bucket(bucket_name) - storage.make_bucket(bucket_name) - assert storage.has_bucket(bucket_name) + def test_make_bucket_produces_bucket(self, storage): + storage.clear_bucket() + storage.make_bucket() + assert storage.has_bucket() - def test_listing_bucket_files_yields_all_files_in_bucket(self, storage, bucket_name): - storage.clear_bucket(bucket_name) - storage.put_object(bucket_name, "file1", b"content 1") - storage.put_object(bucket_name, "file2", b"content 2") - full_names_received = storage.get_all_object_names(bucket_name) - assert {(bucket_name, "file1"), (bucket_name, "file2")} == {*full_names_received} + def test_listing_bucket_files_yields_all_files_in_bucket(self, storage): + storage.clear_bucket() + storage.put_object("file1", b"content 1") + 
storage.put_object("file2", b"content 2") + full_names_received = storage.get_all_object_names() + assert {(storage.bucket, "file1"), (storage.bucket, "file2")} == {*full_names_received} - def test_data_loading_failure_raised_if_object_not_present(self, storage, bucket_name): - storage.clear_bucket(bucket_name) + def test_data_loading_failure_raised_if_object_not_present(self, storage): + storage.clear_bucket() with pytest.raises(Exception): - storage.get_object(bucket_name, "folder/file") + storage.get_object("folder/file") + + +@pytest.mark.parametrize("tenant_id", ["azure_tenant", "s3_tenant"], scope="class") +@pytest.mark.parametrize("tenant_server_host", ["localhost"], scope="class") +@pytest.mark.parametrize("tenant_server_port", [8000], scope="class") +class TestMultiTenantStorage: + def test_storage_connection_from_tenant_id( + self, tenant_id, tenant_server_mock, settings, tenant_server_host, tenant_server_port + ): + settings["storage"]["tenant_server"]["endpoint"] = f"http://{tenant_server_host}:{tenant_server_port}" + storage = get_storage_from_tenant_id(tenant_id, settings) + + storage.put_object("file", b"content") + data_received = storage.get_object("file") + + assert b"content" == data_received