177 lines
6.3 KiB
Python
177 lines
6.3 KiB
Python
import gzip
|
|
import json
|
|
from time import sleep
|
|
|
|
import pytest
|
|
from fastapi import FastAPI
|
|
|
|
from pyinfra.storage.connection import get_storage_for_tenant
|
|
from pyinfra.storage.utils import (
|
|
download_data_as_specified_in_message,
|
|
upload_data_as_specified_in_message,
|
|
)
|
|
from pyinfra.utils.cipher import encrypt
|
|
from pyinfra.webserver.utils import create_webserver_thread
|
|
|
|
|
|
@pytest.mark.parametrize("storage_backend", ["azure", "s3"], scope="class")
class TestStorage:
    """Behavioural checks executed against every storage backend ("azure" and "s3").

    Every test empties the bucket first, so no test depends on objects left
    behind by another one.
    """

    def test_clearing_bucket_yields_empty_bucket(self, storage):
        """After ``clear_bucket`` no objects remain."""
        storage.clear_bucket()
        remaining = set(storage.get_all_objects())
        assert remaining == set()

    def test_getting_object_put_in_bucket_is_object(self, storage):
        """An object stored under a flat key reads back unchanged."""
        storage.clear_bucket()
        storage.put_object("file", b"content")
        assert storage.get_object("file") == b"content"

    def test_object_put_in_bucket_exists_on_storage(self, storage):
        """``exists`` reports an object that was just written."""
        storage.clear_bucket()
        storage.put_object("file", b"content")
        assert storage.exists("file")

    def test_getting_nested_object_put_in_bucket_is_nested_object(self, storage):
        """Keys containing ``/`` round-trip exactly like flat keys."""
        nested_key = "folder/file"
        storage.clear_bucket()
        storage.put_object(nested_key, b"content")
        assert storage.get_object(nested_key) == b"content"

    def test_getting_objects_put_in_bucket_are_objects(self, storage):
        """``get_all_objects`` yields the contents of flat and nested objects alike."""
        storage.clear_bucket()
        for key, content in (("file1", b"content 1"), ("folder/file2", b"content 2")):
            storage.put_object(key, content)
        assert set(storage.get_all_objects()) == {b"content 1", b"content 2"}

    def test_make_bucket_produces_bucket(self, storage):
        """After ``make_bucket`` the storage reports that its bucket exists."""
        storage.clear_bucket()
        storage.make_bucket()
        assert storage.has_bucket()

    def test_listing_bucket_files_yields_all_files_in_bucket(self, storage):
        """Object names are listed as ``(bucket, key)`` pairs."""
        storage.clear_bucket()
        for key, content in (("file1", b"content 1"), ("file2", b"content 2")):
            storage.put_object(key, content)
        listed = set(storage.get_all_object_names())
        assert listed == {(storage.bucket, key) for key in ("file1", "file2")}

    def test_data_loading_failure_raised_if_object_not_present(self, storage):
        """Reading a key that was never written raises an exception."""
        storage.clear_bucket()
        with pytest.raises(Exception):
            storage.get_object("folder/file")
|
|
|
|
|
|
@pytest.fixture(scope="class")
def tenant_server_mock(settings, tenant_server_host, tenant_server_port):
    """Serve a mock tenant server for the duration of a test class.

    The FastAPI app answers two routes:

    * ``/azure_tenant`` returns ``azureStorageConnection``. Its
      ``connectionString`` is encrypted with
      ``settings.storage.tenant_server.public_key``; its ``containerName`` is
      sent in clear text.
    * ``/s3_tenant`` returns ``s3StorageConnection``. The endpoint, key,
      region and bucket name are sent in clear text; ``secret`` is encrypted
      with the same public key.

    The app runs on ``tenant_server_host:tenant_server_port`` in a daemon
    thread. Instead of sleeping a fixed second, the fixture polls the port until
    it accepts a TCP connection. Tests therefore start as soon as the server
    listens, and fail with a clear message if it never does.

    Raises:
        RuntimeError: if nothing accepts connections on the port within
            10 seconds of starting the server thread.
    """
    import socket
    from time import monotonic

    startup_timeout_s = 10.0
    poll_interval_s = 0.05

    app = FastAPI()

    @app.get("/azure_tenant")
    def get_azure_storage_info():
        return {
            "azureStorageConnection": {
                "connectionString": encrypt(
                    settings.storage.tenant_server.public_key, settings.storage.azure.connection_string
                ),
                "containerName": settings.storage.azure.container,
            }
        }

    @app.get("/s3_tenant")
    def get_s3_storage_info():
        return {
            "s3StorageConnection": {
                "endpoint": settings.storage.s3.endpoint,
                "key": settings.storage.s3.key,
                "secret": encrypt(settings.storage.tenant_server.public_key, settings.storage.s3.secret),
                "region": settings.storage.s3.region,
                "bucketName": settings.storage.s3.bucket,
            }
        }

    thread = create_webserver_thread(app, tenant_server_port, tenant_server_host)
    # Daemon thread: a server that never shuts down cannot keep the test process alive.
    thread.daemon = True
    thread.start()

    # Wait until the port is listening instead of guessing with a fixed sleep.
    deadline = monotonic() + startup_timeout_s
    while True:
        try:
            # Only a TCP connect/close; no HTTP request is sent to the app.
            with socket.create_connection((tenant_server_host, tenant_server_port), timeout=0.5):
                break
        except OSError as err:
            if monotonic() >= deadline:
                raise RuntimeError(
                    f"Mock tenant server did not accept connections on "
                    f"{tenant_server_host}:{tenant_server_port} within {startup_timeout_s} s"
                ) from err
            sleep(poll_interval_s)

    yield

    # NOTE(review): the server thread presumably never exits on its own, so this
    # join likely just waits out its timeout; kept as originally written.
    thread.join(timeout=1)
|
|
|
|
|
|
@pytest.mark.parametrize("tenant_id", ["azure_tenant", "s3_tenant"], scope="class")
@pytest.mark.parametrize("tenant_server_host", ["localhost"], scope="class")
@pytest.mark.parametrize("tenant_server_port", [8000], scope="class")
class TestMultiTenantStorage:
    """Storage resolved through the mocked tenant server is writable and readable.

    The ``tenant_id`` values match the routes served by ``tenant_server_mock``.
    Presumably ``get_storage_for_tenant`` requests ``<endpoint>/<tenant_id>``
    (to confirm).
    """

    def test_storage_connection_from_tenant_id(
        self, tenant_id, tenant_server_mock, settings, tenant_server_host, tenant_server_port
    ):
        """Resolve the tenant's storage and round-trip a single object through it."""
        tenant_server_url = f"http://{tenant_server_host}:{tenant_server_port}"
        # NOTE(review): this writes into the shared ``settings`` fixture. That
        # fixture is at least class-scoped, because the class-scoped
        # ``tenant_server_mock`` depends on it, so the endpoint stays set after
        # this test. Confirm nothing else relies on (or is broken by) it.
        settings["storage"]["tenant_server"]["endpoint"] = tenant_server_url

        connection_args = (
            tenant_id,
            settings["storage"]["tenant_server"]["endpoint"],
            settings["storage"]["tenant_server"]["public_key"],
        )
        storage = get_storage_for_tenant(*connection_args)

        storage.put_object("file", b"content")
        assert storage.get_object("file") == b"content"
|
|
|
|
|
|
@pytest.fixture
def payload(payload_type):
    """Return the storage message used by the download/upload helpers.

    Supported ``payload_type`` values:

    * ``"target_response_file_path"``: explicit ``targetFilePath`` and
      ``responseFilePath``.
    * ``"dossier_id_file_id"``: ``dossierId``/``fileId`` together with the
      target and response file extensions.
    * ``"target_file_dict"``: ``targetFilePath`` is a mapping of names to
      target paths (both entries point at the same file), plus a
      ``responseFilePath``.

    Raises:
        ValueError: if ``payload_type`` is not one of the values above. A typo
            in a parametrization then fails loudly instead of handing the test
            ``None``.
    """
    # Built on every call so a test that mutates its payload cannot leak into another test.
    payloads = {
        "target_response_file_path": {
            "targetFilePath": "test/file.target.json.gz",
            "responseFilePath": "test/file.response.json.gz",
        },
        "dossier_id_file_id": {
            "dossierId": "test",
            "fileId": "file",
            "targetFileExtension": "target.json.gz",
            "responseFileExtension": "response.json.gz",
        },
        "target_file_dict": {
            "targetFilePath": {"file_1": "test/file.target.json.gz", "file_2": "test/file.target.json.gz"},
            "responseFilePath": "test/file.response.json.gz",
        },
    }
    try:
        return payloads[payload_type]
    except KeyError:
        raise ValueError(
            f"Unsupported payload_type {payload_type!r}; expected one of {sorted(payloads)}"
        ) from None
|
|
|
|
|
|
@pytest.fixture
def expected_data(payload_type):
    """Return the data the download helper should produce for ``payload_type``.

    * ``"target_response_file_path"``: the stored document itself.
    * ``"dossier_id_file_id"``: the stored document, which also carries the
      dossier and file ids.
    * ``"target_file_dict"``: one document per key of the payload's
      ``targetFilePath`` mapping.

    Raises:
        ValueError: if ``payload_type`` is not one of the values above, instead
            of silently returning ``None``.
    """
    # Built on every call so a test that mutates its expectation cannot leak into another test.
    expectations = {
        "target_response_file_path": {"data": "success"},
        "dossier_id_file_id": {"dossierId": "test", "fileId": "file", "data": "success"},
        "target_file_dict": {"file_1": {"data": "success"}, "file_2": {"data": "success"}},
    }
    try:
        return expectations[payload_type]
    except KeyError:
        raise ValueError(
            f"Unsupported payload_type {payload_type!r}; expected one of {sorted(expectations)}"
        ) from None
|
|
|
|
|
|
@pytest.mark.parametrize(
    "payload_type",
    [
        "target_response_file_path",
        "dossier_id_file_id",
        "target_file_dict",
    ],
    scope="class",
)
@pytest.mark.parametrize("storage_backend", ["azure", "s3"], scope="class")
class TestDownloadAndUploadFromMessage:
    """Round-trip each message payload shape through download and upload on every backend."""

    def test_download_and_upload_from_message(self, storage, payload, expected_data, payload_type):
        """Download the target named by ``payload``, then upload and re-read the response.

        The downloaded data must equal ``expected_data``. The uploaded response
        file, once gunzipped and JSON-decoded, must equal ``payload`` extended
        with a ``"data"`` entry holding ``expected_data``.
        """
        storage.clear_bucket()

        # In the dict payload both entries point at the same target file, so one
        # stored document ("file_1"'s expectation) serves every entry.
        if payload_type == "target_file_dict":
            target_document = expected_data["file_1"]
        else:
            target_document = expected_data
        serialized_target = json.dumps(target_document).encode()
        storage.put_object("test/file.target.json.gz", gzip.compress(serialized_target))

        downloaded = download_data_as_specified_in_message(storage, payload)
        assert downloaded == expected_data

        upload_data_as_specified_in_message(storage, payload, expected_data)
        compressed_response = storage.get_object("test/file.response.json.gz")
        response = json.loads(gzip.decompress(compressed_response).decode())

        assert response == dict(payload, data=expected_data)
|