The download function is now overloaded: in addition to a single string file path, it also accepts a dict whose values are file paths. When a dict is given, the downloaded data is returned as a dict with the same keys.
177 lines
6.3 KiB
Python
177 lines
6.3 KiB
Python
import gzip
|
|
import json
|
|
from time import sleep
|
|
|
|
import pytest
|
|
from fastapi import FastAPI
|
|
|
|
from pyinfra.storage.connection import get_storage_for_tenant
|
|
from pyinfra.storage.utils import (
|
|
download_data_as_specified_in_message,
|
|
upload_data_as_specified_in_message,
|
|
)
|
|
from pyinfra.utils.cipher import encrypt
|
|
from pyinfra.webserver.utils import create_webserver_thread
|
|
|
|
|
|
@pytest.mark.parametrize("storage_backend", ["azure", "s3"], scope="class")
class TestStorage:
    """Exercise basic object-storage operations once per ``storage_backend``.

    The ``storage`` fixture is presumably built from ``storage_backend`` in conftest
    (not visible here); every test starts by clearing the bucket.
    """

    def test_clearing_bucket_yields_empty_bucket(self, storage):
        storage.clear_bucket()
        assert not set(storage.get_all_objects())

    def test_getting_object_put_in_bucket_is_object(self, storage):
        storage.clear_bucket()
        storage.put_object("file", b"content")
        assert storage.get_object("file") == b"content"

    def test_object_put_in_bucket_exists_on_storage(self, storage):
        storage.clear_bucket()
        storage.put_object("file", b"content")
        assert storage.exists("file")

    def test_getting_nested_object_put_in_bucket_is_nested_object(self, storage):
        storage.clear_bucket()
        storage.put_object("folder/file", b"content")
        assert storage.get_object("folder/file") == b"content"

    def test_getting_objects_put_in_bucket_are_objects(self, storage):
        storage.clear_bucket()
        for object_name, object_content in (("file1", b"content 1"), ("folder/file2", b"content 2")):
            storage.put_object(object_name, object_content)
        assert set(storage.get_all_objects()) == {b"content 1", b"content 2"}

    def test_make_bucket_produces_bucket(self, storage):
        storage.clear_bucket()
        storage.make_bucket()
        assert storage.has_bucket()

    def test_listing_bucket_files_yields_all_files_in_bucket(self, storage):
        storage.clear_bucket()
        for object_name, object_content in (("file1", b"content 1"), ("file2", b"content 2")):
            storage.put_object(object_name, object_content)
        listed_names = set(storage.get_all_object_names())
        # Names are reported as (bucket, object name) pairs.
        assert listed_names == {(storage.bucket, "file1"), (storage.bucket, "file2")}

    def test_data_loading_failure_raised_if_object_not_present(self, storage):
        storage.clear_bucket()
        # The exact exception type is backend-specific, hence the broad match.
        with pytest.raises(Exception):
            storage.get_object("folder/file")
|
|
|
|
|
|
@pytest.fixture(scope="class")
def tenant_server_mock(settings, tenant_server_host, tenant_server_port):
    """Serve mock tenant storage info on ``tenant_server_host:tenant_server_port``.

    The FastAPI app exposes two routes:

    * ``/azure_tenant`` - Azure connection info with the connection string encrypted
      using ``settings.storage.tenant_server.public_key``.
    * ``/s3_tenant`` - S3 connection info with the secret encrypted the same way.

    The fixture blocks until the server accepts TCP connections, then yields ``None``.

    Raises:
        RuntimeError: if the server is not reachable within the startup timeout.
    """
    import socket
    from time import monotonic

    startup_timeout_s = 10.0
    poll_interval_s = 0.05

    app = FastAPI()

    @app.get("/azure_tenant")
    def get_azure_storage_info():
        return {
            "azureStorageConnection": {
                "connectionString": encrypt(
                    settings.storage.tenant_server.public_key, settings.storage.azure.connection_string
                ),
                "containerName": settings.storage.azure.container,
            }
        }

    @app.get("/s3_tenant")
    def get_s3_storage_info():
        return {
            "s3StorageConnection": {
                "endpoint": settings.storage.s3.endpoint,
                "key": settings.storage.s3.key,
                "secret": encrypt(settings.storage.tenant_server.public_key, settings.storage.s3.secret),
                "region": settings.storage.s3.region,
                "bucketName": settings.storage.s3.bucket,
            }
        }

    thread = create_webserver_thread(app, tenant_server_port, tenant_server_host)
    # Daemon thread: the process may exit even though the server is never stopped explicitly.
    thread.daemon = True
    thread.start()

    # Poll for readiness instead of sleeping a fixed time: a fixed sleep is flaky when
    # startup is slow and wastes time when it is fast.
    deadline = monotonic() + startup_timeout_s
    while True:
        try:
            with socket.create_connection((tenant_server_host, tenant_server_port), timeout=0.5):
                break
        except OSError as err:
            if monotonic() >= deadline:
                raise RuntimeError(
                    f"Mock tenant server did not start on {tenant_server_host}:{tenant_server_port} "
                    f"within {startup_timeout_s} s"
                ) from err
            sleep(poll_interval_s)

    yield

    # NOTE(review): nothing here signals the server to stop, so this join presumably
    # just times out; the daemon flag above keeps it from blocking interpreter exit.
    thread.join(timeout=1)
|
|
|
|
|
|
@pytest.mark.parametrize("tenant_id", ["azure_tenant", "s3_tenant"], scope="class")
@pytest.mark.parametrize("tenant_server_host", ["localhost"], scope="class")
@pytest.mark.parametrize("tenant_server_port", [8000], scope="class")
class TestMultiTenantStorage:
    """Resolve a tenant's storage through the mock tenant server and use it.

    The ``tenant_id`` values match the route names served by ``tenant_server_mock``;
    ``get_storage_for_tenant`` presumably queries ``<endpoint>/<tenant_id>``.
    """

    def test_storage_connection_from_tenant_id(
        self, tenant_id, tenant_server_mock, settings, tenant_server_host, tenant_server_port
    ):
        # ``tenant_server_mock`` is requested only for its side effect (running the server).
        # The endpoint is kept local rather than written into ``settings``. If ``settings``
        # is shared across tests (e.g. a session-scoped fixture), mutating it would leak
        # this test's host/port into unrelated tests.
        tenant_server_endpoint = f"http://{tenant_server_host}:{tenant_server_port}"

        storage = get_storage_for_tenant(
            tenant_id,
            tenant_server_endpoint,
            settings["storage"]["tenant_server"]["public_key"],
        )

        storage.put_object("file", b"content")
        data_received = storage.get_object("file")

        assert b"content" == data_received
|
|
|
|
|
|
@pytest.fixture
def payload(payload_type):
    """Return the message payload describing where to read/write data for ``payload_type``.

    Supported shapes:

    * ``"target_response_file_path"`` - explicit target and response file paths.
    * ``"dossier_id_file_id"`` - dossier/file ids plus target/response file extensions,
      which the storage utils presumably combine into the same ``test/file.*.json.gz`` paths.
    * ``"target_file_dict"`` - a dict of target paths (both entries point at the same
      object) plus a single response file path.

    A fresh dict is built on every call, so tests cannot leak mutations into each other.

    Raises:
        ValueError: if ``payload_type`` is not one of the shapes above, so a typo in a
            parametrization fails loudly instead of silently yielding ``None``.
    """
    payloads = {
        "target_response_file_path": {
            "targetFilePath": "test/file.target.json.gz",
            "responseFilePath": "test/file.response.json.gz",
        },
        "dossier_id_file_id": {
            "dossierId": "test",
            "fileId": "file",
            "targetFileExtension": "target.json.gz",
            "responseFileExtension": "response.json.gz",
        },
        "target_file_dict": {
            "targetFilePath": {"file_1": "test/file.target.json.gz", "file_2": "test/file.target.json.gz"},
            "responseFilePath": "test/file.response.json.gz",
        },
    }
    try:
        return payloads[payload_type]
    except KeyError:
        raise ValueError(f"Unknown payload_type: {payload_type!r}") from None
|
|
|
|
|
|
@pytest.fixture
def expected_data(payload_type):
    """Return the data ``download_data_as_specified_in_message`` should yield for ``payload_type``.

    For the single-path shapes this is also the exact JSON stored in the target object.
    For ``"target_file_dict"`` the result mirrors the keys of the payload's
    ``targetFilePath`` dict, each mapped to the decoded content of that file.

    Raises:
        ValueError: if ``payload_type`` is unknown, instead of silently yielding ``None``.
    """
    expected_by_type = {
        "target_response_file_path": {"data": "success"},
        "dossier_id_file_id": {"dossierId": "test", "fileId": "file", "data": "success"},
        "target_file_dict": {"file_1": {"data": "success"}, "file_2": {"data": "success"}},
    }
    try:
        return expected_by_type[payload_type]
    except KeyError:
        raise ValueError(f"Unknown payload_type: {payload_type!r}") from None
|
|
|
|
|
|
@pytest.mark.parametrize(
    "payload_type",
    ["target_response_file_path", "dossier_id_file_id", "target_file_dict"],
    scope="class",
)
@pytest.mark.parametrize("storage_backend", ["azure", "s3"], scope="class")
class TestDownloadAndUploadFromMessage:
    """Round-trip message-driven download and upload against each storage backend."""

    def test_download_and_upload_from_message(self, storage, payload, expected_data, payload_type):
        storage.clear_bucket()

        # In the dict case every entry points at the same target path, so a single
        # stored object (the content of one entry) serves all keys.
        if payload_type == "target_file_dict":
            stored_content = expected_data["file_1"]
        else:
            stored_content = expected_data
        compressed_target = gzip.compress(json.dumps(stored_content).encode())
        storage.put_object("test/file.target.json.gz", compressed_target)

        downloaded = download_data_as_specified_in_message(storage, payload)
        assert expected_data == downloaded

        upload_data_as_specified_in_message(storage, payload, expected_data)
        raw_response = storage.get_object("test/file.response.json.gz")
        uploaded = json.loads(gzip.decompress(raw_response).decode())

        # The response object carries the original payload fields plus the data.
        assert {**payload, "data": expected_data} == uploaded
|