diff --git a/pyinfra/payload_processing/payload.py b/pyinfra/payload_processing/payload.py index 4a54b15..178effb 100644 --- a/pyinfra/payload_processing/payload.py +++ b/pyinfra/payload_processing/payload.py @@ -1,11 +1,10 @@ from dataclasses import dataclass -from functools import singledispatch -from funcy import project +from functools import singledispatch, partial +from funcy import project, complement from itertools import chain from operator import itemgetter -from typing import Union, Sized +from typing import Union, Sized, Callable, List -from kn_utils.logging import logger from pyinfra.config import Config from pyinfra.utils.file_extension_parsing import make_file_extension_parser @@ -42,78 +41,105 @@ class LegacyQueueMessagePayload(QueueMessagePayload): class QueueMessagePayloadParser: - def __init__(self, file_extension_parser, allowed_processing_parameters): - self.parse_file_extensions = file_extension_parser - self.allowed_processing_params = allowed_processing_parameters + def __init__(self, payload_matcher2parse_strategy: dict): + self.payload_matcher2parse_strategy = payload_matcher2parse_strategy def __call__(self, payload: dict) -> QueueMessagePayload: - if maybe_legacy_payload(payload): - logger.debug("Legacy payload detected.") - return self._parse_legacy_queue_message_payload(payload) - else: - return self._parse_queue_message_payload(payload) - - def _parse_queue_message_payload(self, payload: dict) -> QueueMessagePayload: - target_file_path, response_file_path = itemgetter("targetFilePath", "responseFilePath")(payload) - - target_file_type, target_compression_type, response_file_type, response_compression_type = chain.from_iterable( - map(self.parse_file_extensions, [target_file_path, response_file_path]) - ) - - x_tenant_id = payload.get("X-TENANT-ID") - - processing_kwargs = project(payload, self.allowed_processing_params) - - return QueueMessagePayload( - target_file_path=target_file_path, - response_file_path=response_file_path, - 
target_file_type=target_file_type, - target_compression_type=target_compression_type, - response_file_type=response_file_type, - response_compression_type=response_compression_type, - x_tenant_id=x_tenant_id, - processing_kwargs=processing_kwargs, - ) - - def _parse_legacy_queue_message_payload(self, payload: dict) -> LegacyQueueMessagePayload: - dossier_id, file_id, target_file_extension, response_file_extension = itemgetter( - "dossierId", "fileId", "targetFileExtension", "responseFileExtension" - )(payload) - - target_file_path = f"{dossier_id}/{file_id}.{target_file_extension}" - response_file_path = f"{dossier_id}/{file_id}.{response_file_extension}" - - target_file_type, target_compression_type, response_file_type, response_compression_type = chain.from_iterable( - map(self.parse_file_extensions, [target_file_extension, response_file_extension]) - ) - - x_tenant_id = payload.get("X-TENANT-ID") - - processing_kwargs = project(payload, self.allowed_processing_params) - - return LegacyQueueMessagePayload( - dossier_id=dossier_id, - file_id=file_id, - x_tenant_id=x_tenant_id, - target_file_extension=target_file_extension, - response_file_extension=response_file_extension, - target_file_type=target_file_type, - target_compression_type=target_compression_type, - response_file_type=response_file_type, - response_compression_type=response_compression_type, - target_file_path=target_file_path, - response_file_path=response_file_path, - processing_kwargs=processing_kwargs, - ) - - -def maybe_legacy_payload(payload: dict) -> bool: - return {"dossierId", "fileId", "targetFileExtension", "responseFileExtension"}.issubset(payload.keys()) + for payload_matcher, parse_strategy in self.payload_matcher2parse_strategy.items(): + if payload_matcher(payload): + return parse_strategy(payload) def get_queue_message_payload_parser(config: Config) -> QueueMessagePayloadParser: file_extension_parser = make_file_extension_parser(config.allowed_file_types, 
config.allowed_compression_types) - return QueueMessagePayloadParser(file_extension_parser, config.allowed_processing_parameters) + + payload_matcher2parse_strategy = get_payload_matcher2parse_strategy( + file_extension_parser, config.allowed_processing_parameters + ) + + return QueueMessagePayloadParser(payload_matcher2parse_strategy) + + +def get_payload_matcher2parse_strategy(parse_file_extensions: Callable, allowed_processing_parameters: List[str]): + return { + is_legacy_payload: partial( + parse_legacy_queue_message_payload, + parse_file_extensions=parse_file_extensions, + allowed_processing_parameters=allowed_processing_parameters, + ), + complement(is_legacy_payload): partial( + parse_queue_message_payload, + parse_file_extensions=parse_file_extensions, + allowed_processing_parameters=allowed_processing_parameters, + ), + } + + +def is_legacy_payload(payload: dict) -> bool: + return {"dossierId", "fileId", "targetFileExtension", "responseFileExtension"}.issubset(payload.keys()) + + +def parse_queue_message_payload( + payload: dict, + parse_file_extensions: Callable, + allowed_processing_parameters: List[str], +) -> QueueMessagePayload: + target_file_path, response_file_path = itemgetter("targetFilePath", "responseFilePath")(payload) + + target_file_type, target_compression_type, response_file_type, response_compression_type = chain.from_iterable( + map(parse_file_extensions, [target_file_path, response_file_path]) + ) + + x_tenant_id = payload.get("X-TENANT-ID") + + processing_kwargs = project(payload, allowed_processing_parameters) + + return QueueMessagePayload( + target_file_path=target_file_path, + response_file_path=response_file_path, + target_file_type=target_file_type, + target_compression_type=target_compression_type, + response_file_type=response_file_type, + response_compression_type=response_compression_type, + x_tenant_id=x_tenant_id, + processing_kwargs=processing_kwargs, + ) + + +def parse_legacy_queue_message_payload( + payload: dict, + 
parse_file_extensions: Callable, + allowed_processing_parameters: List[str], +) -> LegacyQueueMessagePayload: + dossier_id, file_id, target_file_extension, response_file_extension = itemgetter( + "dossierId", "fileId", "targetFileExtension", "responseFileExtension" + )(payload) + + target_file_path = f"{dossier_id}/{file_id}.{target_file_extension}" + response_file_path = f"{dossier_id}/{file_id}.{response_file_extension}" + + target_file_type, target_compression_type, response_file_type, response_compression_type = chain.from_iterable( + map(parse_file_extensions, [target_file_extension, response_file_extension]) + ) + + x_tenant_id = payload.get("X-TENANT-ID") + + processing_kwargs = project(payload, allowed_processing_parameters) + + return LegacyQueueMessagePayload( + dossier_id=dossier_id, + file_id=file_id, + x_tenant_id=x_tenant_id, + target_file_extension=target_file_extension, + response_file_extension=response_file_extension, + target_file_type=target_file_type, + target_compression_type=target_compression_type, + response_file_type=response_file_type, + response_compression_type=response_compression_type, + target_file_path=target_file_path, + response_file_path=response_file_path, + processing_kwargs=processing_kwargs, + ) @singledispatch diff --git a/pyinfra/payload_processing/processor.py b/pyinfra/payload_processing/processor.py index c125857..2670af8 100644 --- a/pyinfra/payload_processing/processor.py +++ b/pyinfra/payload_processing/processor.py @@ -12,7 +12,7 @@ from pyinfra.payload_processing.payload import ( QueueMessagePayload, ) from pyinfra.storage.storage import make_downloader, make_uploader -from pyinfra.storage.storage_manager import StorageManager +from pyinfra.storage.storage_provider import StorageProvider logger = getLogger() logger.setLevel(get_config().logging_level_root) @@ -21,21 +21,21 @@ logger.setLevel(get_config().logging_level_root) class PayloadProcessor: def __init__( self, - storage_manager: StorageManager, + 
storage_provider: StorageProvider, payload_parser: QueueMessagePayloadParser, data_processor: Callable, ): """Wraps an analysis function specified by a service (e.g. NER service) in pre- and post-processing steps. Args: - storage_manager: Storage manager that connects to the storage, using the tenant id if provided + storage_provider: Storage provider that connects to the storage, using the tenant id if provided payload_parser: Parser that translates the queue message payload to the required QueueMessagePayload object data_processor: The analysis function to be called with the downloaded file NOTE: The result of the analysis function has to be an instance of `Sized`, e.g. a dict or a list to be able to upload it and to be able to monitor the processing time. """ self.parse_payload = payload_parser - self.connect_storage = storage_manager + self.provide_storage = storage_provider self.process_data = data_processor def __call__(self, queue_message_payload: dict) -> dict: @@ -65,13 +65,13 @@ class PayloadProcessor: logger.info(f"Processing {payload.__class__.__name__} ...") logger.debug(f"Payload contents: {asdict(payload)} ...") - storage, storage_info = self.connect_storage(payload.x_tenant_id) + storage, storage_info = self.provide_storage(payload.x_tenant_id) download_file_to_process = make_downloader( - storage, storage_info.bucket, payload.target_file_type, payload.target_compression_type + storage, storage_info.bucket_name, payload.target_file_type, payload.target_compression_type ) upload_processing_result = make_uploader( - storage, storage_info.bucket, payload.response_file_type, payload.response_compression_type + storage, storage_info.bucket_name, payload.response_file_type, payload.response_compression_type ) data = download_file_to_process(payload.target_file_path) @@ -87,14 +87,14 @@ def make_payload_processor(data_processor: Callable, config: Config = None) -> P """Creates a payload processor.""" config = config or get_config() - storage_manager = 
StorageManager(config) + storage_provider = StorageProvider(config) monitor = get_monitor_from_config(config) payload_parser: QueueMessagePayloadParser = get_queue_message_payload_parser(config) data_processor = monitor(data_processor) return PayloadProcessor( - storage_manager, + storage_provider, payload_parser, data_processor, ) diff --git a/pyinfra/storage/storage_manager.py b/pyinfra/storage/storage_provider.py similarity index 89% rename from pyinfra/storage/storage_manager.py rename to pyinfra/storage/storage_provider.py index b0e2338..e5e6900 100644 --- a/pyinfra/storage/storage_manager.py +++ b/pyinfra/storage/storage_provider.py @@ -4,11 +4,12 @@ from kn_utils.logging import logger from typing import Tuple from pyinfra.config import Config -from pyinfra.storage.storage_info import get_storage_info_from_config, get_storage_info_from_endpoint, StorageInfo +from pyinfra.storage.storage_info import get_storage_info_from_config, get_storage_info_from_endpoint, StorageInfo, \ + get_storage_from_storage_info from pyinfra.storage.storages.interface import Storage -class StorageManager: +class StorageProvider: def __init__(self, config: Config): self.config = config self.default_storage_info: StorageInfo = get_storage_info_from_config(config) @@ -25,7 +26,7 @@ class StorageManager: @lru_cache(maxsize=32) def connect(self, x_tenant_id=None) -> Tuple[Storage, StorageInfo]: storage_info = self._get_storage_info(x_tenant_id) - storage_connection = storage_info.get_storage() + storage_connection = get_storage_from_storage_info(storage_info) return storage_connection, storage_info def _get_storage_info(self, x_tenant_id=None):