From 48d74b430779c8778d259ae2cde71f3b5c3eb2aa Mon Sep 17 00:00:00 2001 From: Julius Unverfehrt Date: Fri, 18 Aug 2023 11:15:47 +0200 Subject: [PATCH] Add support for absolute file paths Introduces new payload parsing logic to be able to process absolute file paths. The queue message is expected to contain the keys "targetFilePath" and "responseFilePath". To ensure backward-compatibility, the legacy "dossierId", "fileId" messages are still supported. --- pyinfra/config.py | 2 + pyinfra/payload_processing/payload.py | 143 +++++++++++++++++------- pyinfra/payload_processing/processor.py | 27 ++--- tests/payload_parsing_test.py | 10 +- tests/payload_processor_test.py | 4 +- 5 files changed, 124 insertions(+), 62 deletions(-) diff --git a/pyinfra/config.py b/pyinfra/config.py index fe2d72e..396e20d 100644 --- a/pyinfra/config.py +++ b/pyinfra/config.py @@ -92,6 +92,8 @@ class Config: self.allowed_file_types = ["json", "pdf"] self.allowed_compression_types = ["gz"] + self.allowed_processing_parameters = ["operation"] + # config for x-tenant-endpoint to receive storage connection information per tenant self.tenant_decryption_public_key = read_from_environment("TENANT_PUBLIC_KEY", "redaction") self.tenant_endpoint = read_from_environment("TENANT_ENDPOINT", "http://tenant-user-management:8081/internal-api/tenants") diff --git a/pyinfra/payload_processing/payload.py b/pyinfra/payload_processing/payload.py index ef5f394..5557d43 100644 --- a/pyinfra/payload_processing/payload.py +++ b/pyinfra/payload_processing/payload.py @@ -1,59 +1,97 @@ from dataclasses import dataclass +from functools import singledispatch +from funcy import project from itertools import chain from operator import itemgetter from typing import Union, Sized -from funcy import project - +from pyinfra import logger from pyinfra.config import Config from pyinfra.utils.file_extension_parsing import make_file_extension_parser @dataclass class QueueMessagePayload: - dossier_id: str - file_id: str - x_tenant_id: Union[str, None] + """Default one-to-one payload, where the message contains the absolute file paths for the target and response files, + that have to be acquired from the storage.""" - target_file_extension: str - response_file_extension: str + target_file_path: str + response_file_path: str target_file_type: Union[str, None] target_compression_type: Union[str, None] response_file_type: Union[str, None] response_compression_type: Union[str, None] - target_file_name: str - response_file_name: str + x_tenant_id: Union[str, None] processing_kwargs: dict +@dataclass +class LegacyQueueMessagePayload(QueueMessagePayload): + """Legacy one-to-one payload, where the message contains the dossier and file ids, and the file extensions that have + to be used to construct the absolute file paths for the target and response files, that have to be acquired from the + storage.""" + + dossier_id: str + file_id: str + + target_file_extension: str + response_file_extension: str + + class QueueMessagePayloadParser: - def __init__(self, file_extension_parser, allowed_processing_args=("operation",)): + def __init__(self, file_extension_parser, allowed_processing_parameters): self.parse_file_extensions = file_extension_parser - self.allowed_args = allowed_processing_args + self.allowed_processing_params = allowed_processing_parameters def __call__(self, payload: dict) -> QueueMessagePayload: - """Translate the queue message payload to the internal QueueMessagePayload object.""" - return self._parse_queue_message_payload(payload) + if maybe_legacy_payload(payload): + logger.debug("Legacy payload detected.") + return self._parse_legacy_queue_message_payload(payload) + else: + return self._parse_queue_message_payload(payload) def _parse_queue_message_payload(self, payload: dict) -> QueueMessagePayload: + target_file_path, response_file_path = itemgetter("targetFilePath", "responseFilePath")(payload) + + target_file_type, target_compression_type, response_file_type, response_compression_type = chain.from_iterable( + map(self.parse_file_extensions, [target_file_path, response_file_path]) + ) + + x_tenant_id = payload.get("X-TENANT-ID") + + processing_kwargs = project(payload, self.allowed_processing_params) + + return QueueMessagePayload( + target_file_path=target_file_path, + response_file_path=response_file_path, + target_file_type=target_file_type, + target_compression_type=target_compression_type, + response_file_type=response_file_type, + response_compression_type=response_compression_type, + x_tenant_id=x_tenant_id, + processing_kwargs=processing_kwargs, + ) + + def _parse_legacy_queue_message_payload(self, payload: dict) -> LegacyQueueMessagePayload: dossier_id, file_id, target_file_extension, response_file_extension = itemgetter( "dossierId", "fileId", "targetFileExtension", "responseFileExtension" )(payload) - x_tenant_id = payload.get("X-TENANT-ID") + + target_file_path = f"{dossier_id}/{file_id}.{target_file_extension}" + response_file_path = f"{dossier_id}/{file_id}.{response_file_extension}" target_file_type, target_compression_type, response_file_type, response_compression_type = chain.from_iterable( map(self.parse_file_extensions, [target_file_extension, response_file_extension]) ) - target_file_name = f"{dossier_id}/{file_id}.{target_file_extension}" - response_file_name = f"{dossier_id}/{file_id}.{response_file_extension}" + x_tenant_id = payload.get("X-TENANT-ID") - processing_kwargs = project(payload, self.allowed_args) + processing_kwargs = project(payload, self.allowed_processing_params) - return QueueMessagePayload( + return LegacyQueueMessagePayload( dossier_id=dossier_id, file_id=file_id, x_tenant_id=x_tenant_id, @@ -63,36 +101,59 @@ class QueueMessagePayloadParser: target_compression_type=target_compression_type, response_file_type=response_file_type, response_compression_type=response_compression_type, - target_file_name=target_file_name, - response_file_name=response_file_name, + target_file_path=target_file_path, + response_file_path=response_file_path, processing_kwargs=processing_kwargs, ) +def maybe_legacy_payload(payload: dict) -> bool: + return {"dossierId", "fileId", "targetFileExtension", "responseFileExtension"}.issubset(payload.keys()) + + def get_queue_message_payload_parser(config: Config) -> QueueMessagePayloadParser: file_extension_parser = make_file_extension_parser(config.allowed_file_types, config.allowed_compression_types) - return QueueMessagePayloadParser(file_extension_parser) + return QueueMessagePayloadParser(file_extension_parser, config.allowed_processing_parameters) -class QueueMessagePayloadFormatter: - @staticmethod - def format_service_processing_result_for_storage( - queue_message_payload: QueueMessagePayload, service_processing_result: Sized - ) -> dict: - """Format the results of a processing function with the QueueMessagePayload for the storage upload.""" - return { - "dossierId": queue_message_payload.dossier_id, - "fileId": queue_message_payload.file_id, - "targetFileExtension": queue_message_payload.target_file_extension, - "responseFileExtension": queue_message_payload.response_file_extension, - "data": service_processing_result, - } - - @staticmethod - def format_to_queue_message_response_body(queue_message_payload: QueueMessagePayload) -> dict: - """Format QueueMessagePayload for the AMPQ response after processing.""" - return {"dossierId": queue_message_payload.dossier_id, "fileId": queue_message_payload.file_id} +@singledispatch +def format_service_processing_result_for_storage(payload: QueueMessagePayload, result: Sized) -> dict: + raise NotImplementedError("Unsupported payload type") -def get_queue_message_payload_formatter() -> QueueMessagePayloadFormatter: - return QueueMessagePayloadFormatter() +@format_service_processing_result_for_storage.register(LegacyQueueMessagePayload) +def _(payload: LegacyQueueMessagePayload, result: Sized) -> dict: + return { + "dossierId": payload.dossier_id, + "fileId": payload.file_id, + "targetFileExtension": payload.target_file_extension, + "responseFileExtension": payload.response_file_extension, + "data": result, + } + + +@format_service_processing_result_for_storage.register(QueueMessagePayload) +def _(payload: QueueMessagePayload, result: Sized) -> dict: + return { + "targetFilePath": payload.target_file_path, + "responseFilePath": payload.response_file_path, + "data": result, + } + + +@singledispatch +def format_to_queue_message_response_body(queue_message_payload: QueueMessagePayload) -> dict: + raise NotImplementedError("Unsupported payload type") + + +@format_to_queue_message_response_body.register(LegacyQueueMessagePayload) +def _(queue_message_payload: LegacyQueueMessagePayload) -> dict: + return {"dossierId": queue_message_payload.dossier_id, "fileId": queue_message_payload.file_id} + + +@format_to_queue_message_response_body.register(QueueMessagePayload) +def _(queue_message_payload: QueueMessagePayload) -> dict: + return { + "targetFilePath": queue_message_payload.target_file_path, + "responseFilePath": queue_message_payload.response_file_path, + } diff --git a/pyinfra/payload_processing/processor.py b/pyinfra/payload_processing/processor.py index e4359f9..5f98d22 100644 --- a/pyinfra/payload_processing/processor.py +++ b/pyinfra/payload_processing/processor.py @@ -8,8 +8,9 @@ from pyinfra.payload_processing.monitor import get_monitor_from_config from pyinfra.payload_processing.payload import ( QueueMessagePayloadParser, get_queue_message_payload_parser, - QueueMessagePayloadFormatter, - get_queue_message_payload_formatter, + format_service_processing_result_for_storage, + format_to_queue_message_response_body, + QueueMessagePayload, ) from pyinfra.storage.storage import make_downloader, make_uploader from pyinfra.storage.storage_info import ( @@ -29,7 +30,6 @@ class PayloadProcessor: default_storage_info: StorageInfo, get_storage_info_from_tenant_id, payload_parser: QueueMessagePayloadParser, - payload_formatter: QueueMessagePayloadFormatter, data_processor: Callable, ): """Wraps an analysis function specified by a service (e.g. NER service) in pre- and post-processing steps. @@ -39,14 +39,11 @@ class PayloadProcessor: x_tenant_id is not provided in the queue payload. get_storage_info_from_tenant_id: Callable to acquire storage info from a given tenant id. payload_parser: Parser that translates the queue message payload to the required QueueMessagePayload object - payload_formatter: Formatter for the storage upload result and the queue message response body data_processor: The analysis function to be called with the downloaded file NOTE: The result of the analysis function has to be an instance of `Sized`, e.g. a dict or a list to be able to upload it and to be able to monitor the processing time. """ self.parse_payload = payload_parser - self.format_result_for_storage = payload_formatter.format_service_processing_result_for_storage - self.format_to_queue_message_response_body = payload_formatter.format_to_queue_message_response_body self.process_data = data_processor self.get_storage_info_from_tenant_id = get_storage_info_from_tenant_id @@ -71,8 +68,11 @@ class PayloadProcessor: return self._process(queue_message_payload) def _process(self, queue_message_payload: dict) -> dict: - payload = self.parse_payload(queue_message_payload) - logger.info(f"Processing {asdict(payload)} ...") + logger.info(f"Processing Payload ...") + + payload: QueueMessagePayload = self.parse_payload(queue_message_payload) + + logger.debug(f"Payload: {asdict(payload)} ...") storage_info = self._get_storage_info(payload.x_tenant_id) storage = get_storage_from_storage_info(storage_info) @@ -84,15 +84,14 @@ class PayloadProcessor: upload_processing_result = make_uploader( storage, bucket, payload.response_file_type, payload.response_compression_type ) - format_result_for_storage = partial(self.format_result_for_storage, payload) - data = download_file_to_process(payload.target_file_name) + data = download_file_to_process(payload.target_file_path) result: List[dict] = self.process_data(data, **payload.processing_kwargs) - formatted_result = format_result_for_storage(result) + formatted_result = format_service_processing_result_for_storage(payload, result) - upload_processing_result(payload.response_file_name, formatted_result) + upload_processing_result(payload.response_file_path, formatted_result) - return self.format_to_queue_message_response_body(payload) + return format_to_queue_message_response_body(payload) def _get_storage_info(self, x_tenant_id=None): if x_tenant_id: @@ -118,7 +117,6 @@ def make_payload_processor(data_processor: Callable, config: Config = None) -> P ) monitor = get_monitor_from_config(config) payload_parser: QueueMessagePayloadParser = get_queue_message_payload_parser(config) - payload_formatter: QueueMessagePayloadFormatter = get_queue_message_payload_formatter() data_processor = monitor(data_processor) @@ -126,6 +124,5 @@ def make_payload_processor(data_processor: Callable, config: Config = None) -> P default_storage_info, get_storage_info_from_tenant_id, payload_parser, - payload_formatter, data_processor, ) diff --git a/tests/payload_parsing_test.py b/tests/payload_parsing_test.py index 303f301..a66fb61 100644 --- a/tests/payload_parsing_test.py +++ b/tests/payload_parsing_test.py @@ -1,15 +1,15 @@ import pytest from pyinfra.payload_processing.payload import ( - QueueMessagePayload, QueueMessagePayloadParser, + LegacyQueueMessagePayload, ) from pyinfra.utils.file_extension_parsing import make_file_extension_parser @pytest.fixture def expected_parsed_payload(x_tenant_id): - return QueueMessagePayload( + return LegacyQueueMessagePayload( dossier_id="test", file_id="test", x_tenant_id=x_tenant_id, @@ -19,8 +19,8 @@ def expected_parsed_payload(x_tenant_id): target_compression_type="gz", response_file_type="json", response_compression_type="gz", - target_file_name="test/test.json.gz", - response_file_name="test/test.json.gz", + target_file_path="test/test.json.gz", + response_file_path="test/test.json.gz", processing_kwargs={}, ) @@ -32,7 +32,7 @@ def file_extension_parser(allowed_file_types, allowed_compression_types): @pytest.fixture def payload_parser(file_extension_parser): - return QueueMessagePayloadParser(file_extension_parser) + return QueueMessagePayloadParser(file_extension_parser, allowed_processing_parameters=["operation"]) @pytest.mark.parametrize("allowed_file_types,allowed_compression_types", [(["json", "pdf"], ["gz"])]) diff --git a/tests/payload_processor_test.py b/tests/payload_processor_test.py index 0c48ac4..b7a8b23 100644 --- a/tests/payload_processor_test.py +++ b/tests/payload_processor_test.py @@ -66,7 +66,9 @@ class TestPayloadProcessor: with pytest.raises(Exception): payload_processor(payload) - def test_prometheus_endpoint_is_available(self, test_storage_config, monitoring_enabled, storage_backend, x_tenant_id): + def test_prometheus_endpoint_is_available( + self, test_storage_config, monitoring_enabled, storage_backend, x_tenant_id + ): if monitoring_enabled: resp = requests.get( f"http://{test_storage_config.prometheus_host}:{test_storage_config.prometheus_port}/prometheus"