refactor: multi-tenant storage connection

parent 17c5eebdf6
commit ec5ad09fa8
@@ -1,57 +0,0 @@
from funcy import identity
from operator import attrgetter
from prometheus_client import Summary, start_http_server, CollectorRegistry
from time import time
from typing import Callable, Any, Sized

from pyinfra.config import Config


class PrometheusMonitor:
    def __init__(self, prefix: str, host: str, port: int):
        """Register the monitoring metrics and start a webserver where they can be scraped at the endpoint
        http://{host}:{port}/prometheus

        Args:
            prefix: should per convention consist of {product_name}_{service_name}_{parameter_to_monitor};
                parameter_to_monitor is defined by the result of the processing service.
        """
        self.registry = CollectorRegistry()

        self.entity_processing_time_sum = Summary(
            f"{prefix}_processing_time", "Summed up average processing time per entity observed", registry=self.registry
        )

        start_http_server(port, host, self.registry)

    def __call__(self, process_fn: Callable) -> Callable:
        """Monitor the runtime of a function and update the registered metric with the average runtime per
        resulting element.
        """
        return self._add_result_monitoring(process_fn)

    def _add_result_monitoring(self, process_fn: Callable):
        def inner(data: Any, **kwargs):
            start = time()

            result: Sized = process_fn(data, **kwargs)

            runtime = time() - start

            if not result:
                return result

            processing_time_per_entity = runtime / len(result)

            self.entity_processing_time_sum.observe(processing_time_per_entity)

            return result

        return inner


def get_monitor_from_config(config: Config) -> Callable:
    if config.monitoring_enabled:
        return PrometheusMonitor(*attrgetter("prometheus_metric_prefix", "prometheus_host", "prometheus_port")(config))
    else:
        return identity
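For context, the removed monitor was applied as a decorator around a processing function that returns a sized result, so the registered Summary tracks runtime per produced entity. A minimal usage sketch; the metric prefix, port, and processing function below are illustrative and not taken from the codebase:

# Illustrative usage of the removed PrometheusMonitor; all names and values are placeholders.
monitor = PrometheusMonitor(prefix="product_service_entities", host="0.0.0.0", port=9100)

@monitor
def process_batch(data, **kwargs):
    # Must return a Sized result so runtime / len(result) can be observed.
    return [entity for entity in data]

process_batch(list(range(100)))  # observes the per-entity processing time on the Summary metric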
pyinfra/storage/connection.py  (Normal file, 124 lines)

@@ -0,0 +1,124 @@
from functools import lru_cache, partial
from typing import Callable

import requests
from dynaconf import Dynaconf
from funcy import compose
from kn_utils.logging import logger

from pyinfra.storage.storages.azure import get_azure_storage_from_settings
from pyinfra.storage.storages.interface import Storage
from pyinfra.storage.storages.s3 import get_s3_storage_from_settings
from pyinfra.utils.cipher import decrypt
from pyinfra.utils.compressing import get_decompressor, get_compressor
from pyinfra.utils.config_validation import validate_settings, storage_validators, multi_tenant_storage_validators
from pyinfra.utils.encoding import get_decoder, get_encoder


def get_storage(settings: Dynaconf, tenant_id: str = None) -> Storage:
    """Get a storage connection based on the settings.

    If tenant_id is provided, the storage connection information is fetched from the tenant server instead.
    Connections are cached up to settings.storage.cache_size entries.

    In the future, when the default storage from the config is no longer needed (only multi-tenant storage
    will be used), get_storage_from_tenant_id can replace this function directly.
    """
    if tenant_id:
        logger.info(f"Using tenant storage for {tenant_id}.")
        return get_storage_from_tenant_id(tenant_id, settings)
    else:
        logger.info("Using default storage.")
        return get_storage_from_settings(settings)


def get_storage_from_settings(settings: Dynaconf) -> Storage:
    validate_settings(settings, storage_validators)

    @lru_cache(maxsize=settings.storage.cache_size)
    def _get_storage(backend: str) -> Storage:
        return storage_dispatcher[backend](settings)

    return _get_storage(settings.storage.backend)


def get_storage_from_tenant_id(tenant_id: str, settings: Dynaconf) -> Storage:
    validate_settings(settings, multi_tenant_storage_validators)

    @lru_cache(maxsize=settings.storage.cache_size)
    def _get_storage(tenant: str, endpoint: str, public_key: str) -> Storage:
        response = requests.get(f"{endpoint}/{tenant}").json()

        maybe_azure = response.get("azureStorageConnection")
        maybe_s3 = response.get("s3StorageConnection")
        assert (maybe_azure or maybe_s3) and not (maybe_azure and maybe_s3), "Exactly one storage backend must be configured for the tenant."

        if maybe_azure:
            connection_string = decrypt(public_key, maybe_azure["connectionString"])
            backend = "azure"
            storage_settings = {
                "storage": {
                    "azure": {
                        "connection_string": connection_string,
                        "container": maybe_azure["containerName"],
                    },
                }
            }
        elif maybe_s3:
            secret = decrypt(public_key, maybe_s3["secret"])
            backend = "s3"
            storage_settings = {
                "storage": {
                    "s3": {
                        "endpoint": maybe_s3["endpoint"],
                        "key": maybe_s3["key"],
                        "secret": secret,
                        "region": maybe_s3["region"],
                        "bucket": maybe_s3["bucketName"],
                    },
                }
            }
        else:
            raise Exception(f"Unknown storage backend in {response}.")

        # Overlay the tenant-specific storage settings on top of the base settings before dispatching.
        merged_settings = Dynaconf()
        merged_settings.update(settings)
        merged_settings.update(storage_settings)

        storage = storage_dispatcher[backend](merged_settings)

        return storage

    return _get_storage(tenant_id, settings.storage.tenant_server.endpoint, settings.storage.tenant_server.public_key)


storage_dispatcher = {
    "azure": get_azure_storage_from_settings,
    "s3": get_s3_storage_from_settings,
}


@lru_cache(maxsize=10)
def make_downloader(storage: Storage, bucket: str, file_type: str, compression_type: str) -> Callable:
    verify = partial(verify_existence, storage, bucket)
    download = storage.get_object  # the bucket is already bound to the storage instance
    decompress = get_decompressor(compression_type)
    decode = get_decoder(file_type)

    return compose(decode, decompress, download, verify)


@lru_cache(maxsize=10)
def make_uploader(storage: Storage, bucket: str, file_type: str, compression_type: str) -> Callable:
    upload = storage.put_object  # the bucket is already bound to the storage instance
    compress = get_compressor(compression_type)
    encode = get_encoder(file_type)

    def inner(file_name, file_bytes):
        upload(file_name, compose(compress, encode)(file_bytes))

    return inner


def verify_existence(storage: Storage, bucket: str, file_name: str) -> str:
    if not storage.exists(file_name):
        raise FileNotFoundError(f"{file_name=} not found on storage in {storage.bucket=}.")
    return file_name
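A minimal sketch of how a service could use the new module, assuming a Dynaconf settings source that provides the storage and tenant_server keys this module validates; the settings file name, tenant id, file type, and compression type are placeholders:

# Illustrative sketch, not part of the commit.
from dynaconf import Dynaconf

from pyinfra.storage.connection import get_storage, make_downloader, make_uploader

settings = Dynaconf(settings_files=["settings.toml"])

# Default storage when no tenant is given, tenant-specific storage otherwise.
default_storage = get_storage(settings)
tenant_storage = get_storage(settings, tenant_id="some-tenant")

# Cached transfer helpers bound to the tenant's bucket; "json" and "gzip" are assumed
# to be valid file/compression types for the encoding and compressing utilities.
download = make_downloader(tenant_storage, tenant_storage.bucket, file_type="json", compression_type="gzip")
upload = make_uploader(tenant_storage, tenant_storage.bucket, file_type="json", compression_type="gzip")

upload("results/run-1.json", {"status": "ok"})  # encode -> compress -> put_object (assumes the encoder accepts Python objects)
payload = download("results/run-1.json")        # verify -> get_object -> decompress -> decode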
@@ -1,51 +0,0 @@
from functools import lru_cache, partial
from typing import Callable

from dynaconf import Dynaconf
from funcy import compose

from pyinfra.storage.storages.interface import Storage
from pyinfra.storage.storages.s3 import get_s3_storage_from_settings
from pyinfra.utils.compressing import get_decompressor, get_compressor
from pyinfra.utils.config_validation import validate_settings, storage_validators
from pyinfra.utils.encoding import get_decoder, get_encoder


def get_storage_from_settings(settings: Dynaconf) -> Storage:
    validate_settings(settings, storage_validators)

    return storage_dispatcher[settings.storage.backend](settings)


storage_dispatcher = {
    "azure": get_s3_storage_from_settings,
    "s3": get_s3_storage_from_settings,
}


@lru_cache(maxsize=10)
def make_downloader(storage: Storage, bucket: str, file_type: str, compression_type: str) -> Callable:
    verify = partial(verify_existence, storage, bucket)
    download = partial(storage.get_object, bucket)
    decompress = get_decompressor(compression_type)
    decode = get_decoder(file_type)

    return compose(decode, decompress, download, verify)


@lru_cache(maxsize=10)
def make_uploader(storage: Storage, bucket: str, file_type: str, compression_type: str) -> Callable:
    upload = partial(storage.put_object, bucket)
    compress = get_compressor(compression_type)
    encode = get_encoder(file_type)

    def inner(file_name, file_bytes):
        upload(file_name, compose(compress, encode)(file_bytes))

    return inner


def verify_existence(storage: Storage, bucket: str, file_name: str) -> str:
    if not storage.exists(bucket, file_name):
        raise FileNotFoundError(f"{file_name=} name not found on storage in {bucket=}.")
    return file_name
@@ -1,125 +0,0 @@
from dataclasses import dataclass

import requests
from azure.storage.blob import BlobServiceClient
from minio import Minio

from pyinfra.config import Config
from pyinfra.exception import UnknownStorageBackend
from pyinfra.storage.storages.azure import AzureStorage
from pyinfra.storage.storages.interface import Storage
from pyinfra.storage.storages.s3 import S3Storage
from pyinfra.utils.cipher import decrypt
from pyinfra.utils.url_parsing import validate_and_parse_s3_endpoint


@dataclass(frozen=True)
class StorageInfo:
    bucket_name: str


@dataclass(frozen=True)
class AzureStorageInfo(StorageInfo):
    connection_string: str

    def __hash__(self):
        return hash(self.connection_string)

    def __eq__(self, other):
        if not isinstance(other, AzureStorageInfo):
            return False
        return self.connection_string == other.connection_string


@dataclass(frozen=True)
class S3StorageInfo(StorageInfo):
    secure: bool
    endpoint: str
    access_key: str
    secret_key: str
    region: str

    def __hash__(self):
        return hash((self.secure, self.endpoint, self.access_key, self.secret_key, self.region))

    def __eq__(self, other):
        if not isinstance(other, S3StorageInfo):
            return False
        return (
            self.secure == other.secure
            and self.endpoint == other.endpoint
            and self.access_key == other.access_key
            and self.secret_key == other.secret_key
            and self.region == other.region
        )


def get_storage_from_storage_info(storage_info: StorageInfo) -> Storage:
    if isinstance(storage_info, AzureStorageInfo):
        return AzureStorage(BlobServiceClient.from_connection_string(conn_str=storage_info.connection_string))
    elif isinstance(storage_info, S3StorageInfo):
        return S3Storage(
            Minio(
                secure=storage_info.secure,
                endpoint=storage_info.endpoint,
                access_key=storage_info.access_key,
                secret_key=storage_info.secret_key,
                region=storage_info.region,
            )
        )
    else:
        raise UnknownStorageBackend()


def get_storage_info_from_endpoint(public_key: str, endpoint: str, x_tenant_id: str) -> StorageInfo:
    resp = requests.get(f"{endpoint}/{x_tenant_id}").json()

    maybe_azure = resp.get("azureStorageConnection")
    maybe_s3 = resp.get("s3StorageConnection")
    assert not (maybe_azure and maybe_s3)

    if maybe_azure:
        connection_string = decrypt(public_key, maybe_azure["connectionString"])
        storage_info = AzureStorageInfo(
            connection_string=connection_string,
            bucket_name=maybe_azure["containerName"],
        )
    elif maybe_s3:
        secure, endpoint = validate_and_parse_s3_endpoint(maybe_s3["endpoint"])
        secret = decrypt(public_key, maybe_s3["secret"])

        storage_info = S3StorageInfo(
            secure=secure,
            endpoint=endpoint,
            access_key=maybe_s3["key"],
            secret_key=secret,
            region=maybe_s3["region"],
            bucket_name=maybe_s3["bucketName"],
        )
    else:
        raise UnknownStorageBackend()

    return storage_info


def get_storage_info_from_config(config: Config) -> StorageInfo:
    if config.storage_backend == "s3":
        storage_info = S3StorageInfo(
            secure=config.storage_secure_connection,
            endpoint=config.storage_endpoint,
            access_key=config.storage_key,
            secret_key=config.storage_secret,
            region=config.storage_region,
            bucket_name=config.storage_bucket,
        )

    elif config.storage_backend == "azure":
        storage_info = AzureStorageInfo(
            connection_string=config.storage_azureconnectionstring,
            bucket_name=config.storage_bucket,
        )

    else:
        raise UnknownStorageBackend(f"Unknown storage backend '{config.storage_backend}'.")

    return storage_info
@@ -1,55 +0,0 @@
from dataclasses import asdict
from functools import partial, lru_cache
from kn_utils.logging import logger
from typing import Tuple

from pyinfra.config import Config
from pyinfra.storage.storage_info import (
    get_storage_info_from_config,
    get_storage_info_from_endpoint,
    StorageInfo,
    get_storage_from_storage_info,
)
from pyinfra.storage.storages.interface import Storage


class StorageProvider:
    def __init__(self, config: Config):
        self.config = config
        self.default_storage_info: StorageInfo = get_storage_info_from_config(config)

        self.get_storage_info_from_tenant_id = partial(
            get_storage_info_from_endpoint,
            config.tenant_decryption_public_key,
            config.tenant_endpoint,
        )

    def __call__(self, *args, **kwargs):
        return self._connect(*args, **kwargs)

    @lru_cache(maxsize=32)
    def _connect(self, x_tenant_id=None) -> Tuple[Storage, StorageInfo]:
        storage_info = self._get_storage_info(x_tenant_id)
        storage_connection = get_storage_from_storage_info(storage_info)
        return storage_connection, storage_info

    def _get_storage_info(self, x_tenant_id=None):
        if x_tenant_id:
            storage_info = self.get_storage_info_from_tenant_id(x_tenant_id)
            logger.debug(f"Received {storage_info.__class__.__name__} for {x_tenant_id} from endpoint.")
            logger.trace(f"{asdict(storage_info)}")
        else:
            storage_info = self.default_storage_info
            logger.debug(f"Using local default {storage_info.__class__.__name__} for {x_tenant_id}.")
            logger.trace(f"{asdict(storage_info)}")

        return storage_info


class StorageProviderMock(StorageProvider):
    def __init__(self, storage, storage_info):
        self.storage = storage
        self.storage_info = storage_info

    def __call__(self, *args, **kwargs):
        return self.storage, self.storage_info
@@ -15,47 +15,52 @@ logging.getLogger("urllib3").setLevel(logging.WARNING)
 
 
 class AzureStorage(Storage):
-    def __init__(self, client: BlobServiceClient):
+    def __init__(self, client: BlobServiceClient, bucket: str):
         self._client: BlobServiceClient = client
+        self._bucket = bucket
 
-    def has_bucket(self, bucket_name):
-        container_client = self._client.get_container_client(bucket_name)
+    @property
+    def bucket(self):
+        return self._bucket
+
+    def has_bucket(self):
+        container_client = self._client.get_container_client(self.bucket)
         return container_client.exists()
 
-    def make_bucket(self, bucket_name):
-        container_client = self._client.get_container_client(bucket_name)
-        container_client if container_client.exists() else self._client.create_container(bucket_name)
+    def make_bucket(self):
+        container_client = self._client.get_container_client(self.bucket)
+        container_client if container_client.exists() else self._client.create_container(self.bucket)
 
-    def __provide_container_client(self, bucket_name) -> ContainerClient:
-        self.make_bucket(bucket_name)
-        container_client = self._client.get_container_client(bucket_name)
+    def __provide_container_client(self) -> ContainerClient:
+        self.make_bucket()
+        container_client = self._client.get_container_client(self.bucket)
         return container_client
 
-    def put_object(self, bucket_name, object_name, data):
+    def put_object(self, object_name, data):
         logger.debug(f"Uploading '{object_name}'...")
-        container_client = self.__provide_container_client(bucket_name)
+        container_client = self.__provide_container_client()
         blob_client = container_client.get_blob_client(object_name)
         blob_client.upload_blob(data, overwrite=True)
 
-    def exists(self, bucket_name, object_name):
-        container_client = self.__provide_container_client(bucket_name)
+    def exists(self, object_name):
+        container_client = self.__provide_container_client()
         blob_client = container_client.get_blob_client(object_name)
         return blob_client.exists()
 
     @retry(tries=3, delay=5, jitter=(1, 3))
-    def get_object(self, bucket_name, object_name):
+    def get_object(self, object_name):
         logger.debug(f"Downloading '{object_name}'...")
 
         try:
-            container_client = self.__provide_container_client(bucket_name)
+            container_client = self.__provide_container_client()
             blob_client = container_client.get_blob_client(object_name)
             blob_data = blob_client.download_blob()
             return blob_data.readall()
         except Exception as err:
             raise Exception("Failed getting object from azure client") from err
 
-    def get_all_objects(self, bucket_name):
-        container_client = self.__provide_container_client(bucket_name)
+    def get_all_objects(self):
+        container_client = self.__provide_container_client()
         blobs = container_client.list_blobs()
         for blob in blobs:
             logger.debug(f"Downloading '{blob.name}'...")
@@ -64,18 +69,22 @@ class AzureStorage(Storage):
             data = blob_data.readall()
             yield data
 
-    def clear_bucket(self, bucket_name):
-        logger.debug(f"Clearing Azure container '{bucket_name}'...")
-        container_client = self._client.get_container_client(bucket_name)
+    def clear_bucket(self):
+        logger.debug(f"Clearing Azure container '{self.bucket}'...")
+        container_client = self._client.get_container_client(self.bucket)
         blobs = container_client.list_blobs()
         container_client.delete_blobs(*blobs)
 
-    def get_all_object_names(self, bucket_name):
-        container_client = self.__provide_container_client(bucket_name)
+    def get_all_object_names(self):
+        container_client = self.__provide_container_client()
         blobs = container_client.list_blobs()
-        return zip(repeat(bucket_name), map(attrgetter("name"), blobs))
+        return zip(repeat(self.bucket), map(attrgetter("name"), blobs))
 
 
 def get_azure_storage_from_settings(settings: Dynaconf):
     validate_settings(settings, azure_storage_validators)
-    return AzureStorage(BlobServiceClient.from_connection_string(conn_str=settings.storage.azure.connection_string))
+
+    return AzureStorage(
+        client=BlobServiceClient.from_connection_string(conn_str=settings.storage.azure.connection_string),
+        bucket=settings.storage.azure.container,
+    )
@@ -2,34 +2,39 @@ from abc import ABC, abstractmethod
 
 
 class Storage(ABC):
+    @property
     @abstractmethod
-    def make_bucket(self, bucket_name):
+    def bucket(self):
        raise NotImplementedError
 
     @abstractmethod
-    def has_bucket(self, bucket_name):
+    def make_bucket(self):
         raise NotImplementedError
 
     @abstractmethod
-    def put_object(self, bucket_name, object_name, data):
+    def has_bucket(self):
         raise NotImplementedError
 
     @abstractmethod
-    def exists(self, bucket_name, object_name):
+    def put_object(self, object_name, data):
         raise NotImplementedError
 
     @abstractmethod
-    def get_object(self, bucket_name, object_name):
+    def exists(self, object_name):
         raise NotImplementedError
 
     @abstractmethod
-    def get_all_objects(self, bucket_name):
+    def get_object(self, object_name):
         raise NotImplementedError
 
     @abstractmethod
-    def clear_bucket(self, bucket_name):
+    def get_all_objects(self):
         raise NotImplementedError
 
     @abstractmethod
-    def get_all_object_names(self, bucket_name):
+    def clear_bucket(self):
         raise NotImplementedError
+
+    @abstractmethod
+    def get_all_object_names(self):
+        raise NotImplementedError
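To make the shape of the new bucket-bound interface concrete, here is a minimal in-memory implementation; it is illustrative only and not part of the commit:

from pyinfra.storage.storages.interface import Storage


class InMemoryStorage(Storage):
    """Toy Storage backend holding objects in a dict; the bucket is fixed at construction time."""

    def __init__(self, bucket: str):
        self._bucket = bucket
        self._objects = {}

    @property
    def bucket(self):
        return self._bucket

    def make_bucket(self):
        pass

    def has_bucket(self):
        return True

    def put_object(self, object_name, data):
        self._objects[object_name] = data

    def exists(self, object_name):
        return object_name in self._objects

    def get_object(self, object_name):
        return self._objects[object_name]

    def get_all_objects(self):
        yield from self._objects.values()

    def clear_bucket(self):
        self._objects.clear()

    def get_all_object_names(self):
        return zip([self.bucket] * len(self._objects), self._objects.keys())


storage = InMemoryStorage(bucket="my-bucket")
storage.put_object("folder/file", b"content")  # no bucket argument anymore
assert storage.get_object("folder/file") == b"content"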
@@ -5,32 +5,35 @@ class StorageMock(Storage):
     def __init__(self, data: bytes = None, file_name: str = None, bucket: str = None):
         self.data = data
         self.file_name = file_name
-        self.bucket = bucket
+        self._bucket = bucket
 
-    def make_bucket(self, bucket_name):
-        self.bucket = bucket_name
+    @property
+    def bucket(self):
+        return self._bucket
 
-    def has_bucket(self, bucket_name):
-        return self.bucket == bucket_name
+    def make_bucket(self):
+        pass
 
-    def put_object(self, bucket_name, object_name, data):
-        self.bucket = bucket_name
+    def has_bucket(self):
+        return True
+
+    def put_object(self, object_name, data):
         self.file_name = object_name
         self.data = data
 
-    def exists(self, bucket_name, object_name):
-        return self.bucket == bucket_name and self.file_name == object_name
+    def exists(self, object_name):
+        return self.file_name == object_name
 
-    def get_object(self, bucket_name, object_name):
+    def get_object(self, object_name):
         return self.data
 
-    def get_all_objects(self, bucket_name):
+    def get_all_objects(self):
         raise NotImplementedError
 
-    def clear_bucket(self, bucket_name):
-        self.bucket = None
+    def clear_bucket(self):
+        self._bucket = None
         self.file_name = None
         self.data = None
 
-    def get_all_object_names(self, bucket_name):
+    def get_all_object_names(self):
         raise NotImplementedError
@@ -13,35 +13,40 @@ from pyinfra.utils.url_parsing import validate_and_parse_s3_endpoint
 
 
 class S3Storage(Storage):
-    def __init__(self, client: Minio):
+    def __init__(self, client: Minio, bucket: str):
         self._client = client
+        self._bucket = bucket
 
-    def make_bucket(self, bucket_name):
-        if not self.has_bucket(bucket_name):
-            self._client.make_bucket(bucket_name)
+    @property
+    def bucket(self):
+        return self._bucket
 
-    def has_bucket(self, bucket_name):
-        return self._client.bucket_exists(bucket_name)
+    def make_bucket(self):
+        if not self.has_bucket():
+            self._client.make_bucket(self.bucket)
 
-    def put_object(self, bucket_name, object_name, data):
+    def has_bucket(self):
+        return self._client.bucket_exists(self.bucket)
+
+    def put_object(self, object_name, data):
         logger.debug(f"Uploading '{object_name}'...")
         data = io.BytesIO(data)
-        self._client.put_object(bucket_name, object_name, data, length=data.getbuffer().nbytes)
+        self._client.put_object(self.bucket, object_name, data, length=data.getbuffer().nbytes)
 
-    def exists(self, bucket_name, object_name):
+    def exists(self, object_name):
         try:
-            self._client.stat_object(bucket_name, object_name)
+            self._client.stat_object(self.bucket, object_name)
             return True
         except Exception:
             return False
 
     @retry(tries=3, delay=5, jitter=(1, 3))
-    def get_object(self, bucket_name, object_name):
+    def get_object(self, object_name):
         logger.debug(f"Downloading '{object_name}'...")
         response = None
 
         try:
-            response = self._client.get_object(bucket_name, object_name)
+            response = self._client.get_object(self.bucket, object_name)
             return response.data
         except Exception as err:
             raise Exception("Failed getting object from s3 client") from err
@@ -50,20 +55,20 @@ class S3Storage(Storage):
             response.close()
             response.release_conn()
 
-    def get_all_objects(self, bucket_name):
-        for obj in self._client.list_objects(bucket_name, recursive=True):
+    def get_all_objects(self):
+        for obj in self._client.list_objects(self.bucket, recursive=True):
             logger.debug(f"Downloading '{obj.object_name}'...")
-            yield self.get_object(bucket_name, obj.object_name)
+            yield self.get_object(obj.object_name)
 
-    def clear_bucket(self, bucket_name):
-        logger.debug(f"Clearing S3 bucket '{bucket_name}'...")
-        objects = self._client.list_objects(bucket_name, recursive=True)
+    def clear_bucket(self):
+        logger.debug(f"Clearing S3 bucket '{self.bucket}'...")
+        objects = self._client.list_objects(self.bucket, recursive=True)
         for obj in objects:
-            self._client.remove_object(bucket_name, obj.object_name)
+            self._client.remove_object(self.bucket, obj.object_name)
 
-    def get_all_object_names(self, bucket_name):
-        objs = self._client.list_objects(bucket_name, recursive=True)
-        return zip(repeat(bucket_name), map(attrgetter("object_name"), objs))
+    def get_all_object_names(self):
+        objs = self._client.list_objects(self.bucket, recursive=True)
+        return zip(repeat(self.bucket), map(attrgetter("object_name"), objs))
 
 
 def get_s3_storage_from_settings(settings: Dynaconf):
@@ -72,11 +77,12 @@ def get_s3_storage_from_settings(settings: Dynaconf):
     secure, endpoint = validate_and_parse_s3_endpoint(settings.storage.s3.endpoint)
 
     return S3Storage(
-        Minio(
+        client=Minio(
             secure=secure,
             endpoint=endpoint,
             access_key=settings.storage.s3.key,
             secret_key=settings.storage.s3.secret,
             region=settings.storage.s3.region,
-        )
+        ),
+        bucket=settings.storage.s3.bucket,
     )
@@ -15,6 +15,7 @@ queue_manager_validators = [
 
 azure_storage_validators = [
     Validator("storage.azure.connection_string", must_exist=True),
+    Validator("storage.azure.container", must_exist=True),
 ]
 
 s3_storage_validators = [
@@ -22,12 +23,19 @@ s3_storage_validators = [
     Validator("storage.s3.key", must_exist=True),
     Validator("storage.s3.secret", must_exist=True),
     Validator("storage.s3.region", must_exist=True),
+    Validator("storage.s3.bucket", must_exist=True),
 ]
 
 storage_validators = [
     Validator("storage.backend", must_exist=True),
 ]
 
+multi_tenant_storage_validators = [
+    Validator("storage.tenant_server.endpoint", must_exist=True),
+    Validator("storage.tenant_server.public_key", must_exist=True),
+]
+
 
 prometheus_validators = [
     Validator("metrics.prometheus.prefix", must_exist=True),
     Validator("metrics.prometheus.enabled", must_exist=True),
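For illustration, a settings object that satisfies the storage, S3, and multi-tenant validators could be built like this; every value is a placeholder, and storage.cache_size is an assumption taken from its use in connection.py rather than from any validator:

# Illustrative sketch, not part of the commit.
from dynaconf import Dynaconf

from pyinfra.utils.config_validation import (
    validate_settings,
    storage_validators,
    s3_storage_validators,
    multi_tenant_storage_validators,
)

settings = Dynaconf()
settings.update(
    {
        "storage": {
            "backend": "s3",
            "cache_size": 8,
            "s3": {
                "endpoint": "https://minio.example.com:9000",
                "key": "example-key",
                "secret": "example-secret",
                "region": "eu-central-1",
                "bucket": "example-bucket",
            },
            "tenant_server": {
                "endpoint": "http://localhost:8000",
                "public_key": "example-public-key",
            },
        }
    }
)

validate_settings(settings, storage_validators + s3_storage_validators + multi_tenant_storage_validators)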
@@ -8,11 +8,11 @@ from fastapi import FastAPI
 from pyinfra.utils.config_validation import validate_settings, webserver_validators
 
 
-def create_webserver_thread(app: FastAPI, settings: Dynaconf) -> threading.Thread:
+def create_webserver_thread_from_settings(app: FastAPI, settings: Dynaconf) -> threading.Thread:
     validate_settings(settings, validators=webserver_validators)
 
-    return threading.Thread(
-        target=lambda: uvicorn.run(
-            app, port=settings.webserver.port, host=settings.webserver.host, log_level=logging.WARNING
-        )
-    )
+    return create_webserver_thread(app=app, port=settings.webserver.port, host=settings.webserver.host)
+
+
+def create_webserver_thread(app: FastAPI, port: int, host: str) -> threading.Thread:
+    return threading.Thread(target=lambda: uvicorn.run(app, port=port, host=host, log_level=logging.WARNING))
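A usage sketch for the extracted helper, mirroring what the tenant-server test fixture further down does; host and port are placeholders:

# Illustrative only: run a FastAPI app in a background thread without Dynaconf settings.
from fastapi import FastAPI

from pyinfra.webserver import create_webserver_thread

app = FastAPI()
thread = create_webserver_thread(app, port=8080, host="127.0.0.1")
thread.daemon = True
thread.start()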
@@ -6,14 +6,14 @@ import requests
 from fastapi import FastAPI
 
 from pyinfra.monitor.prometheus import add_prometheus_endpoint, make_prometheus_processing_time_decorator_from_settings
-from pyinfra.webserver import create_webserver_thread
+from pyinfra.webserver import create_webserver_thread_from_settings
 
 
 @pytest.fixture(scope="class")
 def app_with_prometheus_endpoint(settings):
     app = FastAPI()
     app = add_prometheus_endpoint(app)
-    thread = create_webserver_thread(app, settings)
+    thread = create_webserver_thread_from_settings(app, settings)
     thread.daemon = True
     thread.start()
     sleep(1)
@@ -1,64 +1,119 @@
 from time import sleep
 
 import pytest
+from fastapi import FastAPI
 
-from pyinfra.storage.storage import get_storage_from_settings
+from pyinfra.storage.connection import get_storage_from_settings, get_storage_from_tenant_id
+from pyinfra.utils.cipher import encrypt
+from pyinfra.webserver import create_webserver_thread
 
 
-@pytest.fixture(scope="session")
-def storage(storage_backend, bucket_name, settings):
+@pytest.fixture(scope="class")
+def storage(storage_backend, settings):
     settings.storage.backend = storage_backend
 
     storage = get_storage_from_settings(settings)
-    storage.make_bucket(bucket_name)
+    storage.make_bucket()
 
     yield storage
-    storage.clear_bucket(bucket_name)
+    storage.clear_bucket()
 
 
-@pytest.mark.parametrize("storage_backend", ["azure", "s3"], scope="session")
-@pytest.mark.parametrize("bucket_name", ["bucket"], scope="session")
+@pytest.fixture(scope="class")
+def tenant_server_mock(settings, tenant_server_host, tenant_server_port):
+    app = FastAPI()
+
+    @app.get("/azure_tenant")
+    def get_azure_storage_info():
+        return {
+            "azureStorageConnection": {
+                "connectionString": encrypt(
+                    settings.storage.tenant_server.public_key, settings.storage.azure.connection_string
+                ),
+                "containerName": settings.storage.azure.container,
+            }
+        }
+
+    @app.get("/s3_tenant")
+    def get_s3_storage_info():
+        return {
+            "s3StorageConnection": {
+                "endpoint": settings.storage.s3.endpoint,
+                "key": settings.storage.s3.key,
+                "secret": encrypt(settings.storage.tenant_server.public_key, settings.storage.s3.secret),
+                "region": settings.storage.s3.region,
+                "bucketName": settings.storage.s3.bucket,
+            }
+        }
+
+    thread = create_webserver_thread(app, tenant_server_port, tenant_server_host)
+    thread.daemon = True
+    thread.start()
+    sleep(1)
+    yield
+    thread.join(timeout=1)
+
+
+@pytest.mark.parametrize("storage_backend", ["azure", "s3"], scope="class")
 class TestStorage:
-    def test_clearing_bucket_yields_empty_bucket(self, storage, bucket_name):
-        storage.clear_bucket(bucket_name)
-        data_received = storage.get_all_objects(bucket_name)
+    def test_clearing_bucket_yields_empty_bucket(self, storage):
+        storage.clear_bucket()
+        data_received = storage.get_all_objects()
         assert not {*data_received}
 
-    def test_getting_object_put_in_bucket_is_object(self, storage, bucket_name):
-        storage.clear_bucket(bucket_name)
-        storage.put_object(bucket_name, "file", b"content")
-        data_received = storage.get_object(bucket_name, "file")
+    def test_getting_object_put_in_bucket_is_object(self, storage):
+        storage.clear_bucket()
+        storage.put_object("file", b"content")
+        data_received = storage.get_object("file")
         assert b"content" == data_received
 
-    def test_object_put_in_bucket_exists_on_storage(self, storage, bucket_name):
-        storage.clear_bucket(bucket_name)
-        storage.put_object(bucket_name, "file", b"content")
-        assert storage.exists(bucket_name, "file")
+    def test_object_put_in_bucket_exists_on_storage(self, storage):
+        storage.clear_bucket()
+        storage.put_object("file", b"content")
+        assert storage.exists("file")
 
-    def test_getting_nested_object_put_in_bucket_is_nested_object(self, storage, bucket_name):
-        storage.clear_bucket(bucket_name)
-        storage.put_object(bucket_name, "folder/file", b"content")
-        data_received = storage.get_object(bucket_name, "folder/file")
+    def test_getting_nested_object_put_in_bucket_is_nested_object(self, storage):
+        storage.clear_bucket()
+        storage.put_object("folder/file", b"content")
+        data_received = storage.get_object("folder/file")
         assert b"content" == data_received
 
-    def test_getting_objects_put_in_bucket_are_objects(self, storage, bucket_name):
-        storage.clear_bucket(bucket_name)
-        storage.put_object(bucket_name, "file1", b"content 1")
-        storage.put_object(bucket_name, "folder/file2", b"content 2")
-        data_received = storage.get_all_objects(bucket_name)
+    def test_getting_objects_put_in_bucket_are_objects(self, storage):
+        storage.clear_bucket()
+        storage.put_object("file1", b"content 1")
+        storage.put_object("folder/file2", b"content 2")
+        data_received = storage.get_all_objects()
         assert {b"content 1", b"content 2"} == {*data_received}
 
-    def test_make_bucket_produces_bucket(self, storage, bucket_name):
-        storage.clear_bucket(bucket_name)
-        storage.make_bucket(bucket_name)
-        assert storage.has_bucket(bucket_name)
+    def test_make_bucket_produces_bucket(self, storage):
+        storage.clear_bucket()
+        storage.make_bucket()
+        assert storage.has_bucket()
 
-    def test_listing_bucket_files_yields_all_files_in_bucket(self, storage, bucket_name):
-        storage.clear_bucket(bucket_name)
-        storage.put_object(bucket_name, "file1", b"content 1")
-        storage.put_object(bucket_name, "file2", b"content 2")
-        full_names_received = storage.get_all_object_names(bucket_name)
-        assert {(bucket_name, "file1"), (bucket_name, "file2")} == {*full_names_received}
+    def test_listing_bucket_files_yields_all_files_in_bucket(self, storage):
+        storage.clear_bucket()
+        storage.put_object("file1", b"content 1")
+        storage.put_object("file2", b"content 2")
+        full_names_received = storage.get_all_object_names()
+        assert {(storage.bucket, "file1"), (storage.bucket, "file2")} == {*full_names_received}
 
-    def test_data_loading_failure_raised_if_object_not_present(self, storage, bucket_name):
-        storage.clear_bucket(bucket_name)
+    def test_data_loading_failure_raised_if_object_not_present(self, storage):
+        storage.clear_bucket()
         with pytest.raises(Exception):
-            storage.get_object(bucket_name, "folder/file")
+            storage.get_object("folder/file")
+
+
+@pytest.mark.parametrize("tenant_id", ["azure_tenant", "s3_tenant"], scope="class")
+@pytest.mark.parametrize("tenant_server_host", ["localhost"], scope="class")
+@pytest.mark.parametrize("tenant_server_port", [8000], scope="class")
+class TestMultiTenantStorage:
+    def test_storage_connection_from_tenant_id(
+        self, tenant_id, tenant_server_mock, settings, tenant_server_host, tenant_server_port
+    ):
+        settings["storage"]["tenant_server"]["endpoint"] = f"http://{tenant_server_host}:{tenant_server_port}"
+        storage = get_storage_from_tenant_id(tenant_id, settings)
+
+        storage.put_object("file", b"content")
+        data_received = storage.get_object("file")
+
+        assert b"content" == data_received