"""Integration tests for pyinfra storage backends.

Covers basic bucket/object operations per backend, multi-tenant storage lookup
through a mocked tenant server, and the message-driven download/upload helpers.

NOTE(review): the ``storage`` and ``settings`` fixtures are not defined here;
presumably they come from a conftest.py, with ``storage`` selecting its backend
from the ``storage_backend`` parameter — confirm against conftest.
"""
import gzip
import json
import socket
from time import monotonic, sleep

import pytest
from fastapi import FastAPI

from pyinfra.storage.connection import get_storage_for_tenant
from pyinfra.storage.utils import (
    download_data_bytes_as_specified_in_message,
    upload_data_as_specified_in_message,
)
from pyinfra.utils.cipher import encrypt
from pyinfra.webserver.utils import create_webserver_thread


@pytest.mark.parametrize("storage_backend", ["azure", "s3"], scope="class")
class TestStorage:
    """Basic object/bucket operations, run once per storage backend."""

    def test_clearing_bucket_yields_empty_bucket(self, storage):
        storage.clear_bucket()
        data_received = storage.get_all_objects()
        assert not {*data_received}

    def test_getting_object_put_in_bucket_is_object(self, storage):
        storage.clear_bucket()
        storage.put_object("file", b"content")
        data_received = storage.get_object("file")
        assert b"content" == data_received

    def test_object_put_in_bucket_exists_on_storage(self, storage):
        storage.clear_bucket()
        storage.put_object("file", b"content")
        assert storage.exists("file")

    def test_getting_nested_object_put_in_bucket_is_nested_object(self, storage):
        storage.clear_bucket()
        storage.put_object("folder/file", b"content")
        data_received = storage.get_object("folder/file")
        assert b"content" == data_received

    def test_getting_objects_put_in_bucket_are_objects(self, storage):
        storage.clear_bucket()
        storage.put_object("file1", b"content 1")
        storage.put_object("folder/file2", b"content 2")
        data_received = storage.get_all_objects()
        assert {b"content 1", b"content 2"} == {*data_received}

    def test_make_bucket_produces_bucket(self, storage):
        storage.clear_bucket()
        storage.make_bucket()
        assert storage.has_bucket()

    def test_listing_bucket_files_yields_all_files_in_bucket(self, storage):
        storage.clear_bucket()
        storage.put_object("file1", b"content 1")
        storage.put_object("file2", b"content 2")
        full_names_received = storage.get_all_object_names()
        # Object names are reported as (bucket, object_name) pairs.
        assert {(storage.bucket, "file1"), (storage.bucket, "file2")} == {*full_names_received}

    def test_data_loading_failure_raised_if_object_not_present(self, storage):
        storage.clear_bucket()
        # NOTE(review): the concrete exception type is backend-specific, hence
        # the broad ``Exception``; tighten if the storage layer normalizes it.
        with pytest.raises(Exception):
            storage.get_object("folder/file")


def _wait_for_port(host, port, timeout=10.0, interval=0.1):
    """Block until a TCP connection to ``host:port`` succeeds.

    Args:
        host: Hostname the server should listen on.
        port: TCP port the server should listen on.
        timeout: Maximum number of seconds to wait.
        interval: Delay between connection attempts, and the per-attempt
            connect timeout, in seconds.

    Raises:
        RuntimeError: If nothing accepts connections within ``timeout`` seconds.
    """
    deadline = monotonic() + timeout
    while True:
        try:
            with socket.create_connection((host, port), timeout=interval):
                return
        except OSError as err:
            if monotonic() >= deadline:
                raise RuntimeError(
                    f"Mock tenant server did not start on {host}:{port} within {timeout}s"
                ) from err
            sleep(interval)


@pytest.fixture(scope="class")
def tenant_server_mock(settings, tenant_server_host, tenant_server_port):
    """Serve per-tenant storage connection info from a local FastAPI app.

    Exposes ``/azure_tenant`` and ``/s3_tenant``. Each returns the connection
    details from ``settings``, with the secret part encrypted using the tenant
    server public key.
    """
    app = FastAPI()

    @app.get("/azure_tenant")
    def get_azure_storage_info():
        return {
            "azureStorageConnection": {
                "connectionString": encrypt(
                    settings.storage.tenant_server.public_key, settings.storage.azure.connection_string
                ),
                "containerName": settings.storage.azure.container,
            }
        }

    @app.get("/s3_tenant")
    def get_s3_storage_info():
        return {
            "s3StorageConnection": {
                "endpoint": settings.storage.s3.endpoint,
                "key": settings.storage.s3.key,
                "secret": encrypt(settings.storage.tenant_server.public_key, settings.storage.s3.secret),
                "region": settings.storage.s3.region,
                "bucketName": settings.storage.s3.bucket,
            }
        }

    thread = create_webserver_thread(app, tenant_server_port, tenant_server_host)
    # Daemon thread so a server that is never shut down cannot block interpreter exit.
    thread.daemon = True
    thread.start()
    # Poll the port rather than sleeping a fixed amount: this avoids races on
    # slow machines and needless waiting on fast ones.
    _wait_for_port(tenant_server_host, tenant_server_port)
    yield
    # No explicit shutdown is available here; join briefly and let the daemon
    # thread die with the process.
    thread.join(timeout=1)


@pytest.mark.parametrize("tenant_id", ["azure_tenant", "s3_tenant"], scope="class")
@pytest.mark.parametrize("tenant_server_host", ["localhost"], scope="class")
@pytest.mark.parametrize("tenant_server_port", [8000], scope="class")
class TestMultiTenantStorage:
    """Resolve a storage backend by tenant id through the mocked tenant server."""

    def test_storage_connection_from_tenant_id(
        self, tenant_id, tenant_server_mock, settings, tenant_server_host, tenant_server_port
    ):
        # NOTE(review): this mutates the shared ``settings`` fixture; it is kept
        # for compatibility in case other code reads the endpoint from settings.
        settings["storage"]["tenant_server"]["endpoint"] = f"http://{tenant_server_host}:{tenant_server_port}"
        storage = get_storage_for_tenant(
            tenant_id,
            settings["storage"]["tenant_server"]["endpoint"],
            settings["storage"]["tenant_server"]["public_key"],
        )

        storage.put_object("file", b"content")
        data_received = storage.get_object("file")

        assert b"content" == data_received


@pytest.fixture
def payload(payload_type):
    """Build a message payload of the requested shape.

    Args:
        payload_type: One of ``"target_response_file_path"``,
            ``"dossier_id_file_id"`` or ``"target_file_dict"``.

    Returns:
        A fresh payload dict. Every variant points at
        ``test/file.target.json.gz`` as input and at
        ``test/file.response.json.gz`` as output.

    Raises:
        ValueError: If ``payload_type`` is not one of the supported values.
    """
    if payload_type == "target_response_file_path":
        return {
            "targetFilePath": "test/file.target.json.gz",
            "responseFilePath": "test/file.response.json.gz",
        }
    elif payload_type == "dossier_id_file_id":
        return {
            "dossierId": "test",
            "fileId": "file",
            "targetFileExtension": "target.json.gz",
            "responseFileExtension": "response.json.gz",
        }
    elif payload_type == "target_file_dict":
        return {
            "targetFilePath": {"file_1": "test/file.target.json.gz", "file_2": "test/file.target.json.gz"},
            "responseFilePath": "test/file.response.json.gz",
        }
    # Fail loudly instead of returning None, which would only surface later as
    # an unrelated TypeError inside the test body.
    raise ValueError(f"Unknown payload_type: {payload_type!r}")


@pytest.mark.parametrize(
    "payload_type",
    [
        "target_response_file_path",
        "dossier_id_file_id",
        "target_file_dict",
    ],
    scope="class",
)
@pytest.mark.parametrize("storage_backend", ["azure", "s3"], scope="class")
class TestDownloadAndUploadFromMessage:
    """Round-trip gzipped JSON through the message-driven download/upload helpers."""

    def test_download_and_upload_from_message(self, storage, payload, payload_type):
        storage.clear_bucket()

        result = {"process_result": "success"}
        storage_data = {**payload, "data": result}

        packed_data = gzip.compress(json.dumps(storage_data).encode())

        storage.put_object("test/file.target.json.gz", packed_data)

        # The downloaded bytes are not inspected; this only checks that the
        # download succeeds for every payload shape.
        _ = download_data_bytes_as_specified_in_message(storage, payload)
        upload_data_as_specified_in_message(storage, payload, result)
        data = json.loads(gzip.decompress(storage.get_object("test/file.response.json.gz")).decode())

        assert data == storage_data