diff --git a/pyinfra/payload_processing/payload.py b/pyinfra/payload_processing/payload.py index 4381a8b..ef5f394 100644 --- a/pyinfra/payload_processing/payload.py +++ b/pyinfra/payload_processing/payload.py @@ -3,6 +3,8 @@ from itertools import chain from operator import itemgetter from typing import Union, Sized +from funcy import project + from pyinfra.config import Config from pyinfra.utils.file_extension_parsing import make_file_extension_parser @@ -24,10 +26,13 @@ class QueueMessagePayload: target_file_name: str response_file_name: str + processing_kwargs: dict + class QueueMessagePayloadParser: - def __init__(self, file_extension_parser): + def __init__(self, file_extension_parser, allowed_processing_args=("operation",)): self.parse_file_extensions = file_extension_parser + self.allowed_args = allowed_processing_args def __call__(self, payload: dict) -> QueueMessagePayload: """Translate the queue message payload to the internal QueueMessagePayload object.""" @@ -46,6 +51,8 @@ class QueueMessagePayloadParser: target_file_name = f"{dossier_id}/{file_id}.{target_file_extension}" response_file_name = f"{dossier_id}/{file_id}.{response_file_extension}" + processing_kwargs = project(payload, self.allowed_args) + return QueueMessagePayload( dossier_id=dossier_id, file_id=file_id, @@ -58,6 +65,7 @@ class QueueMessagePayloadParser: response_compression_type=response_compression_type, target_file_name=target_file_name, response_file_name=response_file_name, + processing_kwargs=processing_kwargs, ) diff --git a/pyinfra/payload_processing/processor.py b/pyinfra/payload_processing/processor.py index 0202a0d..d6dccac 100644 --- a/pyinfra/payload_processing/processor.py +++ b/pyinfra/payload_processing/processor.py @@ -1,9 +1,7 @@ import logging from dataclasses import asdict from functools import partial -from typing import Callable, Union, List - -from funcy import compose +from typing import Callable, List from pyinfra.config import get_config, Config from 
pyinfra.payload_processing.monitor import get_monitor_from_config @@ -13,12 +11,12 @@ from pyinfra.payload_processing.payload import ( QueueMessagePayloadFormatter, get_queue_message_payload_formatter, ) -from pyinfra.storage.storage import make_downloader, make_uploader, get_storage_from_storage_info +from pyinfra.storage.storage import make_downloader, make_uploader from pyinfra.storage.storage_info import ( - AzureStorageInfo, - S3StorageInfo, get_storage_info_from_config, get_storage_info_from_endpoint, + StorageInfo, + DefaultStorageCache, ) logger = logging.getLogger() @@ -28,7 +26,8 @@ logger.setLevel(get_config().logging_level_root) class PayloadProcessor: def __init__( self, - default_storage_info: Union[AzureStorageInfo, S3StorageInfo], + default_storage_info: StorageInfo, + storage_cache: DefaultStorageCache, get_storage_info_from_tenant_id, payload_parser: QueueMessagePayloadParser, payload_formatter: QueueMessagePayloadFormatter, @@ -53,8 +52,7 @@ class PayloadProcessor: self.get_storage_info_from_tenant_id = get_storage_info_from_tenant_id self.default_storage_info = default_storage_info - # TODO: use lru-dict - self.storages = {} + self.storages = storage_cache def __call__(self, queue_message_payload: dict) -> dict: """Processes a queue message payload. 
@@ -79,8 +77,8 @@ class PayloadProcessor: logger.info(f"Processing {asdict(payload)} ...") storage_info = self._get_storage_info(payload.x_tenant_id) + storage = self.storages[storage_info] bucket = storage_info.bucket_name - storage = self._get_storage(storage_info) download_file_to_process = make_downloader( storage, bucket, payload.target_file_type, payload.target_compression_type @@ -90,11 +88,11 @@ class PayloadProcessor: ) format_result_for_storage = partial(self.format_result_for_storage, payload) - processing_pipeline = compose(format_result_for_storage, self.process_data, download_file_to_process) + data = download_file_to_process(payload.target_file_name) + result: List[dict] = self.process_data(data, **payload.processing_kwargs) + formatted_result = format_result_for_storage(result) - result: List[dict] = processing_pipeline(payload.target_file_name) - - upload_processing_result(payload.response_file_name, result) + upload_processing_result(payload.response_file_name, formatted_result) return self.format_to_queue_message_response_body(payload) @@ -103,20 +101,13 @@ class PayloadProcessor: return self.get_storage_info_from_tenant_id(x_tenant_id) return self.default_storage_info - def _get_storage(self, storage_info): - if storage_info in self.storages: - return self.storages[storage_info] - else: - storage = get_storage_from_storage_info(storage_info) - self.storages[storage_info] = storage - return storage - -def make_payload_processor(data_processor: Callable, config: Union[None, Config] = None) -> PayloadProcessor: +def make_payload_processor(data_processor: Callable, config: Config = None) -> PayloadProcessor: """Produces payload processor for queue manager.""" config = config or get_config() - default_storage_info: Union[AzureStorageInfo, S3StorageInfo] = get_storage_info_from_config(config) + storage_cache: DefaultStorageCache = DefaultStorageCache(max_size=10) + default_storage_info: StorageInfo = get_storage_info_from_config(config) 
get_storage_info_from_tenant_id = partial( get_storage_info_from_endpoint, config.persistence_service_public_key, @@ -130,6 +121,7 @@ def make_payload_processor(data_processor: Callable, config: Union[None, Config] return PayloadProcessor( default_storage_info, + storage_cache, get_storage_info_from_tenant_id, payload_parser, payload_formatter, diff --git a/pyinfra/storage/storage.py b/pyinfra/storage/storage.py index 5e14bd5..bd849d8 100644 --- a/pyinfra/storage/storage.py +++ b/pyinfra/storage/storage.py @@ -1,16 +1,11 @@ from functools import lru_cache, partial -from typing import Callable, Union +from typing import Callable -from azure.storage.blob import BlobServiceClient from funcy import compose -from minio import Minio from pyinfra.config import Config -from pyinfra.exception import UnknownStorageBackend -from pyinfra.storage.storage_info import AzureStorageInfo, S3StorageInfo, get_storage_info_from_config -from pyinfra.storage.storages.azure import AzureStorage +from pyinfra.storage.storage_info import get_storage_info_from_config, get_storage_from_storage_info from pyinfra.storage.storages.interface import Storage -from pyinfra.storage.storages.s3 import S3Storage from pyinfra.utils.compressing import get_decompressor, get_compressor from pyinfra.utils.encoding import get_decoder, get_encoder @@ -23,23 +18,6 @@ def get_storage_from_config(config: Config) -> Storage: return storage -def get_storage_from_storage_info(storage_info: Union[AzureStorageInfo, S3StorageInfo]) -> Storage: - if isinstance(storage_info, AzureStorageInfo): - return AzureStorage(BlobServiceClient.from_connection_string(conn_str=storage_info.connection_string)) - elif isinstance(storage_info, S3StorageInfo): - return S3Storage( - Minio( - secure=storage_info.secure, - endpoint=storage_info.endpoint, - access_key=storage_info.access_key, - secret_key=storage_info.secret_key, - region=storage_info.region, - ) - ) - else: - raise UnknownStorageBackend() - - def verify_existence(storage: 
Storage, bucket: str, file_name: str) -> str: if not storage.exists(bucket, file_name): raise FileNotFoundError(f"{file_name=} name not found on storage in {bucket=}.") diff --git a/pyinfra/storage/storage_info.py b/pyinfra/storage/storage_info.py index 80d0cdd..362872f 100644 --- a/pyinfra/storage/storage_info.py +++ b/pyinfra/storage/storage_info.py @@ -1,45 +1,113 @@ from dataclasses import dataclass -from typing import Union +from typing import Callable import requests +from azure.storage.blob import BlobServiceClient +from minio import Minio from pyinfra.config import Config from pyinfra.exception import UnknownStorageBackend +from pyinfra.storage.storages.azure import AzureStorage +from pyinfra.storage.storages.interface import Storage +from pyinfra.storage.storages.s3 import S3Storage from pyinfra.utils.cipher import decrypt from pyinfra.utils.url_parsing import validate_and_parse_s3_endpoint @dataclass(frozen=True) -class AzureStorageInfo: - connection_string: str +class StorageInfo: bucket_name: str + +@dataclass(frozen=True) +class AzureStorageInfo(StorageInfo): + connection_string: str + def __hash__(self): return hash(self.connection_string) + def __eq__(self, other): + if not isinstance(other, AzureStorageInfo): + return False + return self.connection_string == other.connection_string + @dataclass(frozen=True) -class S3StorageInfo: +class S3StorageInfo(StorageInfo): secure: bool endpoint: str access_key: str secret_key: str region: str - bucket_name: str def __hash__(self): return hash((self.secure, self.endpoint, self.access_key, self.secret_key, self.region)) + def __eq__(self, other): + if not isinstance(other, S3StorageInfo): + return False + return ( + self.secure == other.secure + and self.endpoint == other.endpoint + and self.access_key == other.access_key + and self.secret_key == other.secret_key + and self.region == other.region + ) -def get_storage_info_from_endpoint( - public_key: str, endpoint: str, x_tenant_id: str -) -> 
Union[AzureStorageInfo, S3StorageInfo]: -    # FIXME: parameterize port, host and public_key -    public_key = "redaction" + +def get_storage_from_storage_info(storage_info: StorageInfo) -> Storage: +    if isinstance(storage_info, AzureStorageInfo): +        return AzureStorage(BlobServiceClient.from_connection_string(conn_str=storage_info.connection_string)) +    elif isinstance(storage_info, S3StorageInfo): +        return S3Storage( +            Minio( +                secure=storage_info.secure, +                endpoint=storage_info.endpoint, +                access_key=storage_info.access_key, +                secret_key=storage_info.secret_key, +                region=storage_info.region, +            ) +        ) +    else: +        raise UnknownStorageBackend() + + +class DefaultStorageCache(dict): +    def __init__(self, max_size=None, get_value: Callable = get_storage_from_storage_info, *args, **kwargs): +        super().__init__(*args, **kwargs) +        self.max_size = max_size +        self.get_value = get_value +        self.access_history = [] + +    def __getitem__(self, key: StorageInfo): +        self.access_history.append(key) +        value = super().get(key) +        if key not in self: +            self.__setitem__(key, value) +            value = super().get(key) + +        return value + +    def __setitem__(self, key: StorageInfo, value: Storage = None): +        self._delete_oldest_key_if_max_size_breached() +        value = value or self.get_value(key) +        super().__setitem__(key, value) +        self.access_history.append(key) + +    def set_key(self, key: StorageInfo): +        self.__setitem__(key) + +    def _delete_oldest_key_if_max_size_breached(self): +        while self.access_history and self.max_size is not None and len(self) >= self.max_size: +            oldest_key = self.access_history.pop(0) +            self.pop(oldest_key, None) + + +def get_storage_info_from_endpoint(public_key: str, endpoint: str, x_tenant_id: str) -> StorageInfo: resp = requests.get(f"{endpoint}/{x_tenant_id}").json() maybe_azure = resp.get("azureStorageConnection") -    maybe_s3 = resp.get("azureStorageConnection") +    maybe_s3 = resp.get("s3StorageConnection") assert not (maybe_azure and maybe_s3) if maybe_azure: @@ -66,7 +134,7 @@ def get_storage_info_from_endpoint( return storage_info -def
get_storage_info_from_config(config: Config) -> Union[AzureStorageInfo, S3StorageInfo]: +def get_storage_info_from_config(config: Config) -> StorageInfo: if config.storage_backend == "s3": storage_info = S3StorageInfo( secure=config.storage_secure_connection, diff --git a/tests/payload_parsing_test.py b/tests/payload_parsing_test.py index 06d3510..303f301 100644 --- a/tests/payload_parsing_test.py +++ b/tests/payload_parsing_test.py @@ -21,6 +21,7 @@ def expected_parsed_payload(x_tenant_id): response_compression_type="gz", target_file_name="test/test.json.gz", response_file_name="test/test.json.gz", + processing_kwargs={}, ) diff --git a/tests/storage_cache_test.py b/tests/storage_cache_test.py new file mode 100644 index 0000000..21152e6 --- /dev/null +++ b/tests/storage_cache_test.py @@ -0,0 +1,47 @@ +from dataclasses import asdict + +import pytest + +from pyinfra.storage.storage_info import AzureStorageInfo, StorageInfo, DefaultStorageCache + + +@pytest.fixture +def storage_infos(): + return [ + AzureStorageInfo(connection_string="first", bucket_name="Thorsten"), + AzureStorageInfo(connection_string="first", bucket_name="Klaus"), + AzureStorageInfo(connection_string="second", bucket_name="Thorsten"), + AzureStorageInfo(connection_string="third", bucket_name="Klaus"), + ] + + +@pytest.fixture +def storage_cache(max_size): + def get_connection_from_storage_info_mock(storage_info: StorageInfo): + return asdict(storage_info) + + return DefaultStorageCache(max_size=max_size, get_value=get_connection_from_storage_info_mock) + + +class TestStorageCache: + def test_same_storage_info_has_same_hash(self, storage_infos): + assert storage_infos[0].__hash__() == storage_infos[1].__hash__() + + @pytest.mark.parametrize("max_size", [4]) + def test_same_connection_different_bucket_does_not_create_new_connection(self, storage_cache, storage_infos): + value = storage_cache[storage_infos[0]] + assert value == asdict((storage_infos[0])) + + value = storage_cache[storage_infos[1]] + 
assert value == asdict((storage_infos[0])) + + @pytest.mark.parametrize("max_size", [2]) + def test_max_size_breached_removes_oldest_key(self, storage_cache, storage_infos): + storage_cache.set_key(storage_infos[0]) + storage_cache.set_key(storage_infos[2]) + storage_cache.set_key(storage_infos[3]) + + assert len(storage_cache) == 2 + assert storage_infos[0] not in storage_cache + assert storage_infos[2] in storage_cache + assert storage_infos[3] in storage_cache