refactor: multi-tenant storage connection

This commit is contained in:
Julius Unverfehrt 2024-01-18 11:22:27 +01:00
parent 17c5eebdf6
commit ec5ad09fa8
13 changed files with 328 additions and 406 deletions

View File

@@ -1,57 +0,0 @@
from funcy import identity
from operator import attrgetter
from prometheus_client import Summary, start_http_server, CollectorRegistry
from time import time
from typing import Callable, Any, Sized
from pyinfra.config import Config
class PrometheusMonitor:
def __init__(self, prefix: str, host: str, port: int):
"""Register the monitoring metrics and start a webserver where they can be scraped at the endpoint
http://{host}:{port}/prometheus
Args:
prefix: should per convention consist of {product_name}_{service_name}_{parameter_to_monitor}
parameter_to_monitor is defined by the result of the processing service.
"""
self.registry = CollectorRegistry()
self.entity_processing_time_sum = Summary(
f"{prefix}_processing_time", "Summed up average processing time per entity observed", registry=self.registry
)
start_http_server(port, host, self.registry)
def __call__(self, process_fn: Callable) -> Callable:
"""Monitor the runtime of a function and update the registered metric with the average runtime per resulting
element.
"""
return self._add_result_monitoring(process_fn)
def _add_result_monitoring(self, process_fn: Callable):
def inner(data: Any, **kwargs):
start = time()
result: Sized = process_fn(data, **kwargs)
runtime = time() - start
if not result:
return result
processing_time_per_entity = runtime / len(result)
self.entity_processing_time_sum.observe(processing_time_per_entity)
return result
return inner
def get_monitor_from_config(config: Config) -> Callable:
if config.monitoring_enabled:
return PrometheusMonitor(*attrgetter("prometheus_metric_prefix", "prometheus_host", "prometheus_port")(config))
else:
return identity

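For context, a minimal sketch of how the removed PrometheusMonitor decorator was used; the prefix, port, and process_batch function are illustrative assumptions, not code from this repository:

monitor = PrometheusMonitor(prefix="product_service_batch", host="0.0.0.0", port=9090)

@monitor  # observes runtime / len(result) on the Summary metric
def process_batch(batch, **kwargs):
    # hypothetical processing function returning a Sized result
    return [entry.upper() for entry in batch]

process_batch(["a", "b"])  # metrics are now scrapable on port 9090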
View File

@@ -0,0 +1,124 @@
from functools import lru_cache, partial
from typing import Callable, Optional
import requests
from dynaconf import Dynaconf
from funcy import compose
from kn_utils.logging import logger
from pyinfra.storage.storages.azure import get_azure_storage_from_settings
from pyinfra.storage.storages.interface import Storage
from pyinfra.storage.storages.s3 import get_s3_storage_from_settings
from pyinfra.utils.cipher import decrypt
from pyinfra.utils.compressing import get_decompressor, get_compressor
from pyinfra.utils.config_validation import validate_settings, storage_validators, multi_tenant_storage_validators
from pyinfra.utils.encoding import get_decoder, get_encoder
def get_storage(settings: Dynaconf, tenant_id: Optional[str] = None) -> Storage:
"""Get storage connection based on settings.
If tenant_id is provided, gets storage connection information from tenant server instead.
The connections are cached based on the settings.cache_size value.
In the future, when the default storage from config is no longer needed (only multi-tenant storage will be used),
get_storage_from_tenant_id can replace this function directly.
"""
if tenant_id:
logger.info(f"Using tenant storage for {tenant_id}.")
return get_storage_from_tenant_id(tenant_id, settings)
else:
logger.info("Using default storage.")
return get_storage_from_settings(settings)
def get_storage_from_settings(settings: Dynaconf) -> Storage:
validate_settings(settings, storage_validators)
@lru_cache(maxsize=settings.storage.cache_size)
def _get_storage(backend: str) -> Storage:
return storage_dispatcher[backend](settings)
return _get_storage(settings.storage.backend)
def get_storage_from_tenant_id(tenant_id: str, settings: Dynaconf) -> Storage:
validate_settings(settings, multi_tenant_storage_validators)
@lru_cache(maxsize=settings.storage.cache_size)
def _get_storage(tenant: str, endpoint: str, public_key: str) -> Storage:
response = requests.get(f"{endpoint}/{tenant}").json()
maybe_azure = response.get("azureStorageConnection")
maybe_s3 = response.get("s3StorageConnection")
assert (maybe_azure or maybe_s3) and not (maybe_azure and maybe_s3), "Exactly one storage backend must be configured."
if maybe_azure:
connection_string = decrypt(public_key, maybe_azure["connectionString"])
backend = "azure"
storage_settings = {
"storage": {
"azure": {
"connection_string": connection_string,
"container": maybe_azure["containerName"],
},
}
}
elif maybe_s3:
secret = decrypt(public_key, maybe_s3["secret"])
backend = "s3"
storage_settings = {
"storage": {
"s3": {
"endpoint": maybe_s3["endpoint"],
"key": maybe_s3["key"],
"secret": secret,
"region": maybe_s3["region"],
"bucket": maybe_s3["bucketName"],
},
}
}
else:
raise Exception(f"Unknown storage backend in {response}.")
tenant_settings = Dynaconf()
tenant_settings.update(settings)
tenant_settings.update(storage_settings)
storage = storage_dispatcher[backend](tenant_settings)
return storage
return _get_storage(tenant_id, settings.storage.tenant_server.endpoint, settings.storage.tenant_server.public_key)
storage_dispatcher = {
"azure": get_azure_storage_from_settings,
"s3": get_s3_storage_from_settings,
}
@lru_cache(maxsize=10)
def make_downloader(storage: Storage, file_type: str, compression_type: str) -> Callable:
verify = partial(verify_existence, storage)
download = storage.get_object
decompress = get_decompressor(compression_type)
decode = get_decoder(file_type)
return compose(decode, decompress, download, verify)
@lru_cache(maxsize=10)
def make_uploader(storage: Storage, file_type: str, compression_type: str) -> Callable:
upload = storage.put_object
compress = get_compressor(compression_type)
encode = get_encoder(file_type)
def inner(file_name, file_bytes):
upload(file_name, compose(compress, encode)(file_bytes))
return inner
def verify_existence(storage: Storage, file_name: str) -> str:
if not storage.exists(file_name):
raise FileNotFoundError(f"{file_name=} not found on storage in {storage.bucket=}.")
return file_name

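A usage sketch for the new connection module; the settings file, tenant id, object name, and the "json"/"gzip" codec names are assumptions (which file and compression types exist depends on pyinfra.utils.encoding and pyinfra.utils.compressing):

from dynaconf import Dynaconf

settings = Dynaconf(settings_files=["settings.toml"])

default_storage = get_storage(settings)                   # backend chosen by storage.backend
tenant_storage = get_storage(settings, tenant_id="acme")  # resolved via the tenant server

upload = make_uploader(tenant_storage, file_type="json", compression_type="gzip")
upload("reports/2024-01-18.json", {"status": "ok"})

download = make_downloader(tenant_storage, file_type="json", compression_type="gzip")
report = download("reports/2024-01-18.json")  # raises FileNotFoundError if missing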
View File

@@ -1,51 +0,0 @@
from functools import lru_cache, partial
from typing import Callable
from dynaconf import Dynaconf
from funcy import compose
from pyinfra.storage.storages.interface import Storage
from pyinfra.storage.storages.s3 import get_s3_storage_from_settings
from pyinfra.utils.compressing import get_decompressor, get_compressor
from pyinfra.utils.config_validation import validate_settings, storage_validators
from pyinfra.utils.encoding import get_decoder, get_encoder
def get_storage_from_settings(settings: Dynaconf) -> Storage:
validate_settings(settings, storage_validators)
return storage_dispatcher[settings.storage.backend](settings)
storage_dispatcher = {
"azure": get_s3_storage_from_settings,
"s3": get_s3_storage_from_settings,
}
@lru_cache(maxsize=10)
def make_downloader(storage: Storage, bucket: str, file_type: str, compression_type: str) -> Callable:
verify = partial(verify_existence, storage, bucket)
download = partial(storage.get_object, bucket)
decompress = get_decompressor(compression_type)
decode = get_decoder(file_type)
return compose(decode, decompress, download, verify)
@lru_cache(maxsize=10)
def make_uploader(storage: Storage, bucket: str, file_type: str, compression_type: str) -> Callable:
upload = partial(storage.put_object, bucket)
compress = get_compressor(compression_type)
encode = get_encoder(file_type)
def inner(file_name, file_bytes):
upload(file_name, compose(compress, encode)(file_bytes))
return inner
def verify_existence(storage: Storage, bucket: str, file_name: str) -> str:
if not storage.exists(bucket, file_name):
raise FileNotFoundError(f"{file_name=} not found on storage in {bucket=}.")
return file_name

View File

@@ -1,125 +0,0 @@
from dataclasses import dataclass
import requests
from azure.storage.blob import BlobServiceClient
from minio import Minio
from pyinfra.config import Config
from pyinfra.exception import UnknownStorageBackend
from pyinfra.storage.storages.azure import AzureStorage
from pyinfra.storage.storages.interface import Storage
from pyinfra.storage.storages.s3 import S3Storage
from pyinfra.utils.cipher import decrypt
from pyinfra.utils.url_parsing import validate_and_parse_s3_endpoint
@dataclass(frozen=True)
class StorageInfo:
bucket_name: str
@dataclass(frozen=True)
class AzureStorageInfo(StorageInfo):
connection_string: str
def __hash__(self):
return hash(self.connection_string)
def __eq__(self, other):
if not isinstance(other, AzureStorageInfo):
return False
return self.connection_string == other.connection_string
@dataclass(frozen=True)
class S3StorageInfo(StorageInfo):
secure: bool
endpoint: str
access_key: str
secret_key: str
region: str
def __hash__(self):
return hash((self.secure, self.endpoint, self.access_key, self.secret_key, self.region))
def __eq__(self, other):
if not isinstance(other, S3StorageInfo):
return False
return (
self.secure == other.secure
and self.endpoint == other.endpoint
and self.access_key == other.access_key
and self.secret_key == other.secret_key
and self.region == other.region
)
def get_storage_from_storage_info(storage_info: StorageInfo) -> Storage:
if isinstance(storage_info, AzureStorageInfo):
return AzureStorage(BlobServiceClient.from_connection_string(conn_str=storage_info.connection_string))
elif isinstance(storage_info, S3StorageInfo):
return S3Storage(
Minio(
secure=storage_info.secure,
endpoint=storage_info.endpoint,
access_key=storage_info.access_key,
secret_key=storage_info.secret_key,
region=storage_info.region,
)
)
else:
raise UnknownStorageBackend()
def get_storage_info_from_endpoint(public_key: str, endpoint: str, x_tenant_id: str) -> StorageInfo:
resp = requests.get(f"{endpoint}/{x_tenant_id}").json()
maybe_azure = resp.get("azureStorageConnection")
maybe_s3 = resp.get("s3StorageConnection")
assert not (maybe_azure and maybe_s3)
if maybe_azure:
connection_string = decrypt(public_key, maybe_azure["connectionString"])
storage_info = AzureStorageInfo(
connection_string=connection_string,
bucket_name=maybe_azure["containerName"],
)
elif maybe_s3:
secure, endpoint = validate_and_parse_s3_endpoint(maybe_s3["endpoint"])
secret = decrypt(public_key, maybe_s3["secret"])
storage_info = S3StorageInfo(
secure=secure,
endpoint=endpoint,
access_key=maybe_s3["key"],
secret_key=secret,
region=maybe_s3["region"],
bucket_name=maybe_s3["bucketName"],
)
else:
raise UnknownStorageBackend()
return storage_info
def get_storage_info_from_config(config: Config) -> StorageInfo:
if config.storage_backend == "s3":
storage_info = S3StorageInfo(
secure=config.storage_secure_connection,
endpoint=config.storage_endpoint,
access_key=config.storage_key,
secret_key=config.storage_secret,
region=config.storage_region,
bucket_name=config.storage_bucket,
)
elif config.storage_backend == "azure":
storage_info = AzureStorageInfo(
connection_string=config.storage_azureconnectionstring,
bucket_name=config.storage_bucket,
)
else:
raise UnknownStorageBackend(f"Unknown storage backend '{config.storage_backend}'.")
return storage_info

View File

@@ -1,55 +0,0 @@
from dataclasses import asdict
from functools import partial, lru_cache
from kn_utils.logging import logger
from typing import Tuple
from pyinfra.config import Config
from pyinfra.storage.storage_info import (
get_storage_info_from_config,
get_storage_info_from_endpoint,
StorageInfo,
get_storage_from_storage_info,
)
from pyinfra.storage.storages.interface import Storage
class StorageProvider:
def __init__(self, config: Config):
self.config = config
self.default_storage_info: StorageInfo = get_storage_info_from_config(config)
self.get_storage_info_from_tenant_id = partial(
get_storage_info_from_endpoint,
config.tenant_decryption_public_key,
config.tenant_endpoint,
)
def __call__(self, *args, **kwargs):
return self._connect(*args, **kwargs)
@lru_cache(maxsize=32)
def _connect(self, x_tenant_id=None) -> Tuple[Storage, StorageInfo]:
storage_info = self._get_storage_info(x_tenant_id)
storage_connection = get_storage_from_storage_info(storage_info)
return storage_connection, storage_info
def _get_storage_info(self, x_tenant_id=None):
if x_tenant_id:
storage_info = self.get_storage_info_from_tenant_id(x_tenant_id)
logger.debug(f"Received {storage_info.__class__.__name__} for {x_tenant_id} from endpoint.")
logger.trace(f"{asdict(storage_info)}")
else:
storage_info = self.default_storage_info
logger.debug(f"Using local default {storage_info.__class__.__name__} for {x_tenant_id}.")
logger.trace(f"{asdict(storage_info)}")
return storage_info
class StorageProviderMock(StorageProvider):
def __init__(self, storage, storage_info):
self.storage = storage
self.storage_info = storage_info
def __call__(self, *args, **kwargs):
return self.storage, self.storage_info

View File

@@ -15,47 +15,52 @@ logging.getLogger("urllib3").setLevel(logging.WARNING)
class AzureStorage(Storage):
def __init__(self, client: BlobServiceClient):
def __init__(self, client: BlobServiceClient, bucket: str):
self._client: BlobServiceClient = client
self._bucket = bucket
def has_bucket(self, bucket_name):
container_client = self._client.get_container_client(bucket_name)
@property
def bucket(self):
return self._bucket
def has_bucket(self):
container_client = self._client.get_container_client(self.bucket)
return container_client.exists()
def make_bucket(self, bucket_name):
container_client = self._client.get_container_client(bucket_name)
container_client if container_client.exists() else self._client.create_container(bucket_name)
def make_bucket(self):
container_client = self._client.get_container_client(self.bucket)
if not container_client.exists():
self._client.create_container(self.bucket)
def __provide_container_client(self, bucket_name) -> ContainerClient:
self.make_bucket(bucket_name)
container_client = self._client.get_container_client(bucket_name)
def __provide_container_client(self) -> ContainerClient:
self.make_bucket()
container_client = self._client.get_container_client(self.bucket)
return container_client
def put_object(self, bucket_name, object_name, data):
def put_object(self, object_name, data):
logger.debug(f"Uploading '{object_name}'...")
container_client = self.__provide_container_client(bucket_name)
container_client = self.__provide_container_client()
blob_client = container_client.get_blob_client(object_name)
blob_client.upload_blob(data, overwrite=True)
def exists(self, bucket_name, object_name):
container_client = self.__provide_container_client(bucket_name)
def exists(self, object_name):
container_client = self.__provide_container_client()
blob_client = container_client.get_blob_client(object_name)
return blob_client.exists()
@retry(tries=3, delay=5, jitter=(1, 3))
def get_object(self, bucket_name, object_name):
def get_object(self, object_name):
logger.debug(f"Downloading '{object_name}'...")
try:
container_client = self.__provide_container_client(bucket_name)
container_client = self.__provide_container_client()
blob_client = container_client.get_blob_client(object_name)
blob_data = blob_client.download_blob()
return blob_data.readall()
except Exception as err:
raise Exception("Failed getting object from azure client") from err
def get_all_objects(self, bucket_name):
container_client = self.__provide_container_client(bucket_name)
def get_all_objects(self):
container_client = self.__provide_container_client()
blobs = container_client.list_blobs()
for blob in blobs:
logger.debug(f"Downloading '{blob.name}'...")
@@ -64,18 +69,22 @@ class AzureStorage(Storage):
data = blob_data.readall()
yield data
def clear_bucket(self, bucket_name):
logger.debug(f"Clearing Azure container '{bucket_name}'...")
container_client = self._client.get_container_client(bucket_name)
def clear_bucket(self):
logger.debug(f"Clearing Azure container '{self.bucket}'...")
container_client = self._client.get_container_client(self.bucket)
blobs = container_client.list_blobs()
container_client.delete_blobs(*blobs)
def get_all_object_names(self, bucket_name):
container_client = self.__provide_container_client(bucket_name)
def get_all_object_names(self):
container_client = self.__provide_container_client()
blobs = container_client.list_blobs()
return zip(repeat(bucket_name), map(attrgetter("name"), blobs))
return zip(repeat(self.bucket), map(attrgetter("name"), blobs))
def get_azure_storage_from_settings(settings: Dynaconf):
validate_settings(settings, azure_storage_validators)
return AzureStorage(BlobServiceClient.from_connection_string(conn_str=settings.storage.azure.connection_string))
return AzureStorage(
client=BlobServiceClient.from_connection_string(conn_str=settings.storage.azure.connection_string),
bucket=settings.storage.azure.container,
)

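What the signature change means for callers, as a sketch; the connection string and container name are placeholders:

from azure.storage.blob import BlobServiceClient

storage = AzureStorage(
    client=BlobServiceClient.from_connection_string(conn_str="<connection string>"),
    bucket="my-container",
)
storage.make_bucket()  # the container is now fixed at construction time
storage.put_object("folder/file.txt", b"payload")
assert storage.exists("folder/file.txt")
data = storage.get_object("folder/file.txt")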
View File

@@ -2,34 +2,39 @@ from abc import ABC, abstractmethod
class Storage(ABC):
@property
@abstractmethod
def make_bucket(self, bucket_name):
def bucket(self):
raise NotImplementedError
@abstractmethod
def has_bucket(self, bucket_name):
def make_bucket(self):
raise NotImplementedError
@abstractmethod
def put_object(self, bucket_name, object_name, data):
def has_bucket(self):
raise NotImplementedError
@abstractmethod
def exists(self, bucket_name, object_name):
def put_object(self, object_name, data):
raise NotImplementedError
@abstractmethod
def get_object(self, bucket_name, object_name):
def exists(self, object_name):
raise NotImplementedError
@abstractmethod
def get_all_objects(self, bucket_name):
def get_object(self, object_name):
raise NotImplementedError
@abstractmethod
def clear_bucket(self, bucket_name):
def get_all_objects(self):
raise NotImplementedError
@abstractmethod
def get_all_object_names(self, bucket_name):
def clear_bucket(self):
raise NotImplementedError
@abstractmethod
def get_all_object_names(self):
raise NotImplementedError

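To show what the slimmer, bucket-scoped interface asks of implementers, a hypothetical local-filesystem backend (not part of this commit; purely illustrative):

from pathlib import Path

class LocalStorage(Storage):
    """Stores each object as a file under root/bucket/."""

    def __init__(self, root: str, bucket: str):
        self._root = Path(root)
        self._bucket = bucket

    @property
    def bucket(self):
        return self._bucket

    def make_bucket(self):
        (self._root / self.bucket).mkdir(parents=True, exist_ok=True)

    def has_bucket(self):
        return (self._root / self.bucket).is_dir()

    def put_object(self, object_name, data):
        path = self._root / self.bucket / object_name
        path.parent.mkdir(parents=True, exist_ok=True)
        path.write_bytes(data)

    def exists(self, object_name):
        return (self._root / self.bucket / object_name).is_file()

    def get_object(self, object_name):
        return (self._root / self.bucket / object_name).read_bytes()

    def get_all_objects(self):
        for path in (self._root / self.bucket).rglob("*"):
            if path.is_file():
                yield path.read_bytes()

    def clear_bucket(self):
        for path in (self._root / self.bucket).rglob("*"):
            if path.is_file():
                path.unlink()

    def get_all_object_names(self):
        base = self._root / self.bucket
        return ((self.bucket, str(p.relative_to(base))) for p in base.rglob("*") if p.is_file())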
View File

@@ -5,32 +5,35 @@ class StorageMock(Storage):
def __init__(self, data: bytes = None, file_name: str = None, bucket: str = None):
self.data = data
self.file_name = file_name
self.bucket = bucket
self._bucket = bucket
def make_bucket(self, bucket_name):
self.bucket = bucket_name
@property
def bucket(self):
return self._bucket
def has_bucket(self, bucket_name):
return self.bucket == bucket_name
def make_bucket(self):
pass
def put_object(self, bucket_name, object_name, data):
self.bucket = bucket_name
def has_bucket(self):
return True
def put_object(self, object_name, data):
self.file_name = object_name
self.data = data
def exists(self, bucket_name, object_name):
return self.bucket == bucket_name and self.file_name == object_name
def exists(self, object_name):
return self.file_name == object_name
def get_object(self, bucket_name, object_name):
def get_object(self, object_name):
return self.data
def get_all_objects(self, bucket_name):
def get_all_objects(self):
raise NotImplementedError
def clear_bucket(self, bucket_name):
self.bucket = None
def clear_bucket(self):
self._bucket = None
self.file_name = None
self.data = None
def get_all_object_names(self, bucket_name):
def get_all_object_names(self):
raise NotImplementedError

View File

@@ -13,35 +13,40 @@ from pyinfra.utils.url_parsing import validate_and_parse_s3_endpoint
class S3Storage(Storage):
def __init__(self, client: Minio):
def __init__(self, client: Minio, bucket: str):
self._client = client
self._bucket = bucket
def make_bucket(self, bucket_name):
if not self.has_bucket(bucket_name):
self._client.make_bucket(bucket_name)
@property
def bucket(self):
return self._bucket
def has_bucket(self, bucket_name):
return self._client.bucket_exists(bucket_name)
def make_bucket(self):
if not self.has_bucket():
self._client.make_bucket(self.bucket)
def put_object(self, bucket_name, object_name, data):
def has_bucket(self):
return self._client.bucket_exists(self.bucket)
def put_object(self, object_name, data):
logger.debug(f"Uploading '{object_name}'...")
data = io.BytesIO(data)
self._client.put_object(bucket_name, object_name, data, length=data.getbuffer().nbytes)
self._client.put_object(self.bucket, object_name, data, length=data.getbuffer().nbytes)
def exists(self, bucket_name, object_name):
def exists(self, object_name):
try:
self._client.stat_object(bucket_name, object_name)
self._client.stat_object(self.bucket, object_name)
return True
except Exception:
return False
@retry(tries=3, delay=5, jitter=(1, 3))
def get_object(self, bucket_name, object_name):
def get_object(self, object_name):
logger.debug(f"Downloading '{object_name}'...")
response = None
try:
response = self._client.get_object(bucket_name, object_name)
response = self._client.get_object(self.bucket, object_name)
return response.data
except Exception as err:
raise Exception("Failed getting object from s3 client") from err
@@ -50,20 +55,20 @@ class S3Storage(Storage):
response.close()
response.release_conn()
def get_all_objects(self, bucket_name):
for obj in self._client.list_objects(bucket_name, recursive=True):
def get_all_objects(self):
for obj in self._client.list_objects(self.bucket, recursive=True):
logger.debug(f"Downloading '{obj.object_name}'...")
yield self.get_object(bucket_name, obj.object_name)
yield self.get_object(obj.object_name)
def clear_bucket(self, bucket_name):
logger.debug(f"Clearing S3 bucket '{bucket_name}'...")
objects = self._client.list_objects(bucket_name, recursive=True)
def clear_bucket(self):
logger.debug(f"Clearing S3 bucket '{self.bucket}'...")
objects = self._client.list_objects(self.bucket, recursive=True)
for obj in objects:
self._client.remove_object(bucket_name, obj.object_name)
self._client.remove_object(self.bucket, obj.object_name)
def get_all_object_names(self, bucket_name):
objs = self._client.list_objects(bucket_name, recursive=True)
return zip(repeat(bucket_name), map(attrgetter("object_name"), objs))
def get_all_object_names(self):
objs = self._client.list_objects(self.bucket, recursive=True)
return zip(repeat(self.bucket), map(attrgetter("object_name"), objs))
def get_s3_storage_from_settings(settings: Dynaconf):
@@ -72,11 +77,12 @@ def get_s3_storage_from_settings(settings: Dynaconf):
secure, endpoint = validate_and_parse_s3_endpoint(settings.storage.s3.endpoint)
return S3Storage(
Minio(
client=Minio(
secure=secure,
endpoint=endpoint,
access_key=settings.storage.s3.key,
secret_key=settings.storage.s3.secret,
region=settings.storage.s3.region,
)
),
bucket=settings.storage.s3.bucket,
)

View File

@@ -15,6 +15,7 @@ queue_manager_validators = [
azure_storage_validators = [
Validator("storage.azure.connection_string", must_exist=True),
Validator("storage.azure.container", must_exist=True),
]
s3_storage_validators = [
@@ -22,12 +23,19 @@ s3_storage_validators = [
Validator("storage.s3.key", must_exist=True),
Validator("storage.s3.secret", must_exist=True),
Validator("storage.s3.region", must_exist=True),
Validator("storage.s3.bucket", must_exist=True),
]
storage_validators = [
Validator("storage.backend", must_exist=True),
]
multi_tenant_storage_validators = [
Validator("storage.tenant_server.endpoint", must_exist=True),
Validator("storage.tenant_server.public_key", must_exist=True),
]
prometheus_validators = [
Validator("metrics.prometheus.prefix", must_exist=True),
Validator("metrics.prometheus.enabled", must_exist=True),

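A settings object satisfying storage_validators plus multi_tenant_storage_validators might look as follows; every value is a placeholder, and cache_size is the knob read by the lru_cache in the new connection module:

from dynaconf import Dynaconf

settings = Dynaconf()
settings.update({
    "storage": {
        "backend": "s3",
        "cache_size": 8,
        "s3": {
            "endpoint": "https://minio.local:9000",
            "key": "minio-key",
            "secret": "minio-secret",
            "region": "eu-central-1",
            "bucket": "tenant-data",
        },
        "tenant_server": {
            "endpoint": "https://tenant-server.local/tenants",
            "public_key": "<decryption key>",
        },
    }
})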
View File

@@ -8,11 +8,11 @@ from fastapi import FastAPI
from pyinfra.utils.config_validation import validate_settings, webserver_validators
def create_webserver_thread(app: FastAPI, settings: Dynaconf) -> threading.Thread:
def create_webserver_thread_from_settings(app: FastAPI, settings: Dynaconf) -> threading.Thread:
validate_settings(settings, validators=webserver_validators)
return threading.Thread(
target=lambda: uvicorn.run(
app, port=settings.webserver.port, host=settings.webserver.host, log_level=logging.WARNING
)
)
return create_webserver_thread(app=app, port=settings.webserver.port, host=settings.webserver.host)
def create_webserver_thread(app: FastAPI, port: int, host: str) -> threading.Thread:
return threading.Thread(target=lambda: uvicorn.run(app, port=port, host=host, log_level=logging.WARNING))

View File

@@ -6,14 +6,14 @@ import requests
from fastapi import FastAPI
from pyinfra.monitor.prometheus import add_prometheus_endpoint, make_prometheus_processing_time_decorator_from_settings
from pyinfra.webserver import create_webserver_thread
from pyinfra.webserver import create_webserver_thread_from_settings
@pytest.fixture(scope="class")
def app_with_prometheus_endpoint(settings):
app = FastAPI()
app = add_prometheus_endpoint(app)
thread = create_webserver_thread(app, settings)
thread = create_webserver_thread_from_settings(app, settings)
thread.daemon = True
thread.start()
sleep(1)

View File

@@ -1,64 +1,119 @@
from time import sleep
import pytest
from fastapi import FastAPI
from pyinfra.storage.storage import get_storage_from_settings
from pyinfra.storage.connection import get_storage_from_settings, get_storage_from_tenant_id
from pyinfra.utils.cipher import encrypt
from pyinfra.webserver import create_webserver_thread
@pytest.fixture(scope="session")
def storage(storage_backend, bucket_name, settings):
@pytest.fixture(scope="class")
def storage(storage_backend, settings):
settings.storage.backend = storage_backend
storage = get_storage_from_settings(settings)
storage.make_bucket(bucket_name)
storage.make_bucket()
yield storage
storage.clear_bucket(bucket_name)
storage.clear_bucket()
@pytest.mark.parametrize("storage_backend", ["azure", "s3"], scope="session")
@pytest.mark.parametrize("bucket_name", ["bucket"], scope="session")
@pytest.fixture(scope="class")
def tenant_server_mock(settings, tenant_server_host, tenant_server_port):
app = FastAPI()
@app.get("/azure_tenant")
def get_azure_storage_info():
return {
"azureStorageConnection": {
"connectionString": encrypt(
settings.storage.tenant_server.public_key, settings.storage.azure.connection_string
),
"containerName": settings.storage.azure.container,
}
}
@app.get("/s3_tenant")
def get_s3_storage_info():
return {
"s3StorageConnection": {
"endpoint": settings.storage.s3.endpoint,
"key": settings.storage.s3.key,
"secret": encrypt(settings.storage.tenant_server.public_key, settings.storage.s3.secret),
"region": settings.storage.s3.region,
"bucketName": settings.storage.s3.bucket,
}
}
thread = create_webserver_thread(app, tenant_server_port, tenant_server_host)
thread.daemon = True
thread.start()
sleep(1)
yield
thread.join(timeout=1)
@pytest.mark.parametrize("storage_backend", ["azure", "s3"], scope="class")
class TestStorage:
def test_clearing_bucket_yields_empty_bucket(self, storage, bucket_name):
storage.clear_bucket(bucket_name)
data_received = storage.get_all_objects(bucket_name)
def test_clearing_bucket_yields_empty_bucket(self, storage):
storage.clear_bucket()
data_received = storage.get_all_objects()
assert not {*data_received}
def test_getting_object_put_in_bucket_is_object(self, storage, bucket_name):
storage.clear_bucket(bucket_name)
storage.put_object(bucket_name, "file", b"content")
data_received = storage.get_object(bucket_name, "file")
def test_getting_object_put_in_bucket_is_object(self, storage):
storage.clear_bucket()
storage.put_object("file", b"content")
data_received = storage.get_object("file")
assert b"content" == data_received
def test_object_put_in_bucket_exists_on_storage(self, storage, bucket_name):
storage.clear_bucket(bucket_name)
storage.put_object(bucket_name, "file", b"content")
assert storage.exists(bucket_name, "file")
def test_object_put_in_bucket_exists_on_storage(self, storage):
storage.clear_bucket()
storage.put_object("file", b"content")
assert storage.exists("file")
def test_getting_nested_object_put_in_bucket_is_nested_object(self, storage, bucket_name):
storage.clear_bucket(bucket_name)
storage.put_object(bucket_name, "folder/file", b"content")
data_received = storage.get_object(bucket_name, "folder/file")
def test_getting_nested_object_put_in_bucket_is_nested_object(self, storage):
storage.clear_bucket()
storage.put_object("folder/file", b"content")
data_received = storage.get_object("folder/file")
assert b"content" == data_received
def test_getting_objects_put_in_bucket_are_objects(self, storage, bucket_name):
storage.clear_bucket(bucket_name)
storage.put_object(bucket_name, "file1", b"content 1")
storage.put_object(bucket_name, "folder/file2", b"content 2")
data_received = storage.get_all_objects(bucket_name)
def test_getting_objects_put_in_bucket_are_objects(self, storage):
storage.clear_bucket()
storage.put_object("file1", b"content 1")
storage.put_object("folder/file2", b"content 2")
data_received = storage.get_all_objects()
assert {b"content 1", b"content 2"} == {*data_received}
def test_make_bucket_produces_bucket(self, storage, bucket_name):
storage.clear_bucket(bucket_name)
storage.make_bucket(bucket_name)
assert storage.has_bucket(bucket_name)
def test_make_bucket_produces_bucket(self, storage):
storage.clear_bucket()
storage.make_bucket()
assert storage.has_bucket()
def test_listing_bucket_files_yields_all_files_in_bucket(self, storage, bucket_name):
storage.clear_bucket(bucket_name)
storage.put_object(bucket_name, "file1", b"content 1")
storage.put_object(bucket_name, "file2", b"content 2")
full_names_received = storage.get_all_object_names(bucket_name)
assert {(bucket_name, "file1"), (bucket_name, "file2")} == {*full_names_received}
def test_listing_bucket_files_yields_all_files_in_bucket(self, storage):
storage.clear_bucket()
storage.put_object("file1", b"content 1")
storage.put_object("file2", b"content 2")
full_names_received = storage.get_all_object_names()
assert {(storage.bucket, "file1"), (storage.bucket, "file2")} == {*full_names_received}
def test_data_loading_failure_raised_if_object_not_present(self, storage, bucket_name):
storage.clear_bucket(bucket_name)
def test_data_loading_failure_raised_if_object_not_present(self, storage):
storage.clear_bucket()
with pytest.raises(Exception):
storage.get_object(bucket_name, "folder/file")
storage.get_object("folder/file")
@pytest.mark.parametrize("tenant_id", ["azure_tenant", "s3_tenant"], scope="class")
@pytest.mark.parametrize("tenant_server_host", ["localhost"], scope="class")
@pytest.mark.parametrize("tenant_server_port", [8000], scope="class")
class TestMultiTenantStorage:
def test_storage_connection_from_tenant_id(
self, tenant_id, tenant_server_mock, settings, tenant_server_host, tenant_server_port
):
settings["storage"]["tenant_server"]["endpoint"] = f"http://{tenant_server_host}:{tenant_server_port}"
storage = get_storage_from_tenant_id(tenant_id, settings)
storage.put_object("file", b"content")
data_received = storage.get_object("file")
assert b"content" == data_received