From b810449bbaf3d93c5ae900ab567c17f486ce85a3 Mon Sep 17 00:00:00 2001 From: Julius Unverfehrt Date: Thu, 18 Apr 2024 16:19:24 +0200 Subject: [PATCH] feat: add multiple file download The download function is now overloaded and additionally supports a dict with file paths as values, in addition to the present string as file path. The data is forwarded as a dict of the same structure in the first case. --- pyinfra/storage/utils.py | 39 ++++++++++++++++++++++++--------- tests/unit_test/storage_test.py | 38 +++++++++++++++++++++++++------- 2 files changed, 59 insertions(+), 18 deletions(-) diff --git a/pyinfra/storage/utils.py b/pyinfra/storage/utils.py index 34f9764..36b5b81 100644 --- a/pyinfra/storage/utils.py +++ b/pyinfra/storage/utils.py @@ -1,5 +1,6 @@ import gzip import json +from functools import singledispatch from typing import Union from kn_utils.logging import logger @@ -29,7 +30,7 @@ class DossierIdFileIdUploadPayload(BaseModel): class TargetResponseFilePathDownloadPayload(BaseModel): - targetFilePath: str + targetFilePath: Union[str, dict] class TargetResponseFilePathUploadPayload(BaseModel): @@ -38,7 +39,8 @@ class TargetResponseFilePathUploadPayload(BaseModel): def download_data_as_specified_in_message(storage: Storage, raw_payload: dict) -> Union[dict, bytes]: """Convenience function to download a file specified in a message payload. - Supports both legacy and new payload formats. + Supports both legacy and new payload formats. Also supports downloading multiple files at once, which should + be specified in a dictionary under the 'targetFilePath' key with the file path as value. If the content is compressed with gzip (.gz), it will be decompressed (-> bytes). If the content is a json file, it will be decoded (-> dict).
@@ -60,18 +62,35 @@ def download_data_as_specified_in_message(storage: Storage, raw_payload: dict) - except ValidationError: raise ValueError("No download file path found in payload, nothing to download.") - if not storage.exists(payload.targetFilePath): - raise FileNotFoundError(f"File '{payload.targetFilePath}' does not exist in storage.") - - data = storage.get_object(payload.targetFilePath) - - data = gzip.decompress(data) if ".gz" in payload.targetFilePath else data - data = json.loads(data.decode("utf-8")) if ".json" in payload.targetFilePath else data - logger.info(f"Downloaded {payload.targetFilePath} from storage.") + data = _download(payload.targetFilePath, storage) return data +@singledispatch +def _download(file_path_or_file_path_dict: Union[str, dict], storage: Storage) -> Union[dict, bytes]: + pass + + +@_download.register(str) +def _download_single_file(file_path: str, storage: Storage) -> bytes: + if not storage.exists(file_path): + raise FileNotFoundError(f"File '{file_path}' does not exist in storage.") + + data = storage.get_object(file_path) + + data = gzip.decompress(data) if ".gz" in file_path else data + data = json.loads(data.decode("utf-8")) if ".json" in file_path else data + logger.info(f"Downloaded {file_path} from storage.") + + return data + + +@_download.register(dict) +def _download_multiple_files(file_path_dict: dict, storage: Storage) -> dict: + return {key: _download(value, storage) for key, value in file_path_dict.items()} + + def upload_data_as_specified_in_message(storage: Storage, raw_payload: dict, data): """Convenience function to upload a file specified in a message payload. For now, only json serializable data is supported. 
The storage json consists of the raw_payload, which is extended with a 'data' key, containing the diff --git a/tests/unit_test/storage_test.py b/tests/unit_test/storage_test.py index dd3ea25..c7d6d40 100644 --- a/tests/unit_test/storage_test.py +++ b/tests/unit_test/storage_test.py @@ -132,23 +132,45 @@ def payload(payload_type): "targetFileExtension": "target.json.gz", "responseFileExtension": "response.json.gz", } + elif payload_type == "target_file_dict": + return { + "targetFilePath": {"file_1": "test/file.target.json.gz", "file_2": "test/file.target.json.gz"}, + "responseFilePath": "test/file.response.json.gz", + } -@pytest.mark.parametrize("payload_type", ["target_response_file_path", "dossier_id_file_id"], scope="class") +@pytest.fixture +def expected_data(payload_type): + if payload_type == "target_response_file_path": + return {"data": "success"} + elif payload_type == "dossier_id_file_id": + return {"dossierId": "test", "fileId": "file", "data": "success"} + elif payload_type == "target_file_dict": + return {"file_1": {"data": "success"}, "file_2": {"data": "success"}} + + +@pytest.mark.parametrize( + "payload_type", + [ + "target_response_file_path", + "dossier_id_file_id", + "target_file_dict", + ], + scope="class", +) @pytest.mark.parametrize("storage_backend", ["azure", "s3"], scope="class") class TestDownloadAndUploadFromMessage: - def test_download_and_upload_from_message(self, storage, payload): + def test_download_and_upload_from_message(self, storage, payload, expected_data, payload_type): storage.clear_bucket() - input_data = {"data": "success"} - - storage.put_object("test/file.target.json.gz", gzip.compress(json.dumps(input_data).encode())) + upload_data = expected_data if payload_type != "target_file_dict" else expected_data["file_1"] + storage.put_object("test/file.target.json.gz", gzip.compress(json.dumps(upload_data).encode())) data = download_data_as_specified_in_message(storage, payload) - assert data == input_data + assert data == 
expected_data - upload_data_as_specified_in_message(storage, payload, input_data) + upload_data_as_specified_in_message(storage, payload, expected_data) data = json.loads(gzip.decompress(storage.get_object("test/file.response.json.gz")).decode()) - assert data == {**payload, "data": input_data} + assert data == {**payload, "data": expected_data}