diff --git a/pyinfra/payload_processing/processor.py b/pyinfra/payload_processing/processor.py
index 5f98d22..19e1823 100644
--- a/pyinfra/payload_processing/processor.py
+++ b/pyinfra/payload_processing/processor.py
@@ -1,6 +1,5 @@
 import logging
 from dataclasses import asdict
-from functools import partial
 from typing import Callable, List
 
 from pyinfra.config import get_config, Config
@@ -13,12 +12,7 @@ from pyinfra.payload_processing.payload import (
     QueueMessagePayload,
 )
 from pyinfra.storage.storage import make_downloader, make_uploader
-from pyinfra.storage.storage_info import (
-    get_storage_info_from_config,
-    get_storage_info_from_endpoint,
-    StorageInfo,
-    get_storage_from_storage_info,
-)
+from pyinfra.storage.storage_manager import StorageManager
 
 logger = logging.getLogger()
 logger.setLevel(get_config().logging_level_root)
@@ -27,28 +21,23 @@ logger.setLevel(get_config().logging_level_root)
 class PayloadProcessor:
     def __init__(
         self,
-        default_storage_info: StorageInfo,
-        get_storage_info_from_tenant_id,
+        storage_manager: StorageManager,
         payload_parser: QueueMessagePayloadParser,
         data_processor: Callable,
     ):
         """Wraps an analysis function specified by a service (e.g. NER service) in pre- and post-processing steps.
 
         Args:
-            default_storage_info: The default storage info used to create the storage connection. This is only used if
-                x_tenant_id is not provided in the queue payload.
-            get_storage_info_from_tenant_id: Callable to acquire storage info from a given tenant id.
+            storage_manager: Storage manager that creates the storage connection, using the tenant id if one is
+                provided in the queue payload.
             payload_parser: Parser that translates the queue message payload to the required QueueMessagePayload object
             data_processor: The analysis function to be called with the downloaded file
                 NOTE: The result of the analysis function has to be an instance of `Sized`, e.g. a dict or a list to be
                     able to upload it and to be able to monitor the processing time.
         """
         self.parse_payload = payload_parser
+        self.connect_storage = storage_manager
         self.process_data = data_processor
-        self.get_storage_info_from_tenant_id = get_storage_info_from_tenant_id
-        self.default_storage_info = default_storage_info
-
 
     def __call__(self, queue_message_payload: dict) -> dict:
         """Processes a queue message payload.
@@ -60,29 +49,30 @@ class PayloadProcessor:
 
         Args:
             queue_message_payload: The payload of a queue message. The payload is expected to be a dict with the
-                following keys: dossierId, fileId, targetFileExtension, responseFileExtension
+                following keys:
+                    targetFilePath, responseFilePath
+                OR
+                    dossierId, fileId, targetFileExtension, responseFileExtension
 
         Returns:
-            The payload for a response queue message. The payload is a dict with the following keys: dossierId, fileId
+            The payload for a response queue message, containing only the request payload.
         """
         return self._process(queue_message_payload)
 
     def _process(self, queue_message_payload: dict) -> dict:
-        logger.info(f"Processing Payload ...")
-        payload: QueueMessagePayload = self.parse_payload(queue_message_payload)
-        logger.debug(f"Payload: {asdict(payload)} ...")
+        payload: QueueMessagePayload = self.parse_payload(queue_message_payload)
+        logger.info(f"Processing {payload.__class__.__name__} ...")
+        logger.debug(f"Payload contents: {asdict(payload)} ...")
 
-        storage_info = self._get_storage_info(payload.x_tenant_id)
-        storage = get_storage_from_storage_info(storage_info)
-        bucket = storage_info.bucket_name
+        storage, storage_info = self.connect_storage(payload.x_tenant_id)
 
         download_file_to_process = make_downloader(
-            storage, bucket, payload.target_file_type, payload.target_compression_type
+            storage, storage_info.bucket_name, payload.target_file_type, payload.target_compression_type
         )
         upload_processing_result = make_uploader(
-            storage, bucket, payload.response_file_type, payload.response_compression_type
+            storage, storage_info.bucket_name, payload.response_file_type, payload.response_compression_type
         )
 
         data = download_file_to_process(payload.target_file_path)
@@ -93,36 +83,19 @@ class PayloadProcessor:
 
         return format_to_queue_message_response_body(payload)
 
-    def _get_storage_info(self, x_tenant_id=None):
-        if x_tenant_id:
-            storage_info = self.get_storage_info_from_tenant_id(x_tenant_id)
-            logger.info(f"Received {storage_info.__class__.__name__} for {x_tenant_id} from endpoint.")
-            logger.debug(f"{asdict(storage_info)}")
-        else:
-            storage_info = self.default_storage_info
-            logger.info(f"Using local default {storage_info.__class__.__name__} for {x_tenant_id}.")
-            logger.debug(f"{asdict(storage_info)}")
-        return storage_info
-
 
 def make_payload_processor(data_processor: Callable, config: Config = None) -> PayloadProcessor:
-    """Produces payload processor for queue manager."""
+    """Creates a payload processor from the given analysis function and config."""
     config = config or get_config()
 
-    default_storage_info: StorageInfo = get_storage_info_from_config(config)
-    get_storage_info_from_tenant_id = partial(
-        get_storage_info_from_endpoint,
-        config.tenant_decryption_public_key,
-        config.tenant_endpoint,
-    )
+    storage_manager = StorageManager(config)
 
     monitor = get_monitor_from_config(config)
     payload_parser: QueueMessagePayloadParser = get_queue_message_payload_parser(config)
     data_processor = monitor(data_processor)
 
     return PayloadProcessor(
-        default_storage_info,
-        get_storage_info_from_tenant_id,
+        storage_manager,
         payload_parser,
         data_processor,
     )
diff --git a/pyinfra/storage/storage_manager.py b/pyinfra/storage/storage_manager.py
new file mode 100644
index 0000000..71d2519
--- /dev/null
+++ b/pyinfra/storage/storage_manager.py
@@ -0,0 +1,50 @@
+from dataclasses import asdict
+from functools import partial, lru_cache
+from typing import Tuple
+
+from pyinfra import logger
+from pyinfra.config import Config
+from pyinfra.storage.storage_info import (
+    get_storage_from_storage_info,
+    get_storage_info_from_config,
+    get_storage_info_from_endpoint,
+    StorageInfo,
+)
+from pyinfra.storage.storages.interface import Storage
+
+
+class StorageManager:
+    """Creates storage connections, caching one connection per tenant id."""
+
+    def __init__(self, config: Config):
+        self.config = config
+        self.default_storage_info: StorageInfo = get_storage_info_from_config(config)
+
+        self.get_storage_info_from_tenant_id = partial(
+            get_storage_info_from_endpoint,
+            config.tenant_decryption_public_key,
+            config.tenant_endpoint,
+        )
+
+    def __call__(self, *args, **kwargs):
+        return self.connect(*args, **kwargs)
+
+    @lru_cache(maxsize=32)
+    def connect(self, x_tenant_id=None) -> Tuple[Storage, StorageInfo]:
+        # lru_cache keys on (self, x_tenant_id), so repeated payloads for the
+        # same tenant reuse the already-established storage connection.
+        storage_info = self._get_storage_info(x_tenant_id)
+        storage_connection = get_storage_from_storage_info(storage_info)
+        return storage_connection, storage_info
+
+    def _get_storage_info(self, x_tenant_id=None):
+        if x_tenant_id:
+            storage_info = self.get_storage_info_from_tenant_id(x_tenant_id)
+            logger.debug(f"Received {storage_info.__class__.__name__} for {x_tenant_id} from endpoint.")
+            logger.debug(f"{asdict(storage_info)}")
+        else:
+            storage_info = self.default_storage_info
+            logger.debug(f"No tenant id provided, using local default {storage_info.__class__.__name__}.")
+            logger.debug(f"{asdict(storage_info)}")
+
+        return storage_info
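
Reviewer note: StorageManager.connect relies on functools.lru_cache applied to an instance method. A self-contained sketch of the resulting behaviour, assuming nothing beyond the standard library (the Demo class and its names are illustrative, not part of pyinfra):

from functools import lru_cache


class Demo:
    @lru_cache(maxsize=32)
    def connect(self, tenant_id=None):
        print(f"establishing connection for {tenant_id}")
        return object()


d = Demo()
assert d.connect("tenant-a") is d.connect("tenant-a")  # second call is a cache hit
d.connect("tenant-b")  # a different tenant id creates a new connection
# Caveat: the cache key includes `self`, so cached entries keep the
# instance (and its connections) alive until they are evicted.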
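
Reviewer note: a minimal end-to-end usage sketch of the new wiring, assuming a fully configured service environment (queue, storage and tenant endpoint settings in Config). The function `extract_entities` and all payload values are hypothetical; `make_payload_processor` and the payload keys come from the diff above:

from typing import List

from pyinfra.payload_processing.processor import make_payload_processor


def extract_entities(data) -> List[dict]:
    # Hypothetical analysis function. The result must be an instance of
    # `Sized` (here: a list) so it can be uploaded and its processing
    # time monitored, as PayloadProcessor requires.
    return [{"text": "ACME Corp", "label": "ORG"}]


process_payload = make_payload_processor(extract_entities)

# Payload keys as documented in PayloadProcessor.__call__ (values hypothetical):
response = process_payload(
    {
        "dossierId": "d-123",
        "fileId": "f-456",
        "targetFileExtension": "json",
        "responseFileExtension": "json",
    }
)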