diff --git a/pyinfra/config.py b/pyinfra/config.py index fe2d72e..396e20d 100644 --- a/pyinfra/config.py +++ b/pyinfra/config.py @@ -92,6 +92,8 @@ class Config: self.allowed_file_types = ["json", "pdf"] self.allowed_compression_types = ["gz"] + self.allowed_processing_parameters = ["operation"] + # config for x-tenant-endpoint to receive storage connection information per tenant self.tenant_decryption_public_key = read_from_environment("TENANT_PUBLIC_KEY", "redaction") self.tenant_endpoint = read_from_environment("TENANT_ENDPOINT", "http://tenant-user-management:8081/internal-api/tenants") diff --git a/pyinfra/payload_processing/payload.py b/pyinfra/payload_processing/payload.py index ef5f394..5557d43 100644 --- a/pyinfra/payload_processing/payload.py +++ b/pyinfra/payload_processing/payload.py @@ -1,59 +1,97 @@ from dataclasses import dataclass +from functools import singledispatch +from funcy import project from itertools import chain from operator import itemgetter from typing import Union, Sized -from funcy import project - +from pyinfra import logger from pyinfra.config import Config from pyinfra.utils.file_extension_parsing import make_file_extension_parser @dataclass class QueueMessagePayload: - dossier_id: str - file_id: str - x_tenant_id: Union[str, None] + """Default one-to-one payload, where the message contains the absolute file paths for the target and response files, + that have to be acquired from the storage.""" - target_file_extension: str - response_file_extension: str + target_file_path: str + response_file_path: str target_file_type: Union[str, None] target_compression_type: Union[str, None] response_file_type: Union[str, None] response_compression_type: Union[str, None] - target_file_name: str - response_file_name: str + x_tenant_id: Union[str, None] processing_kwargs: dict +@dataclass +class LegacyQueueMessagePayload(QueueMessagePayload): + """Legacy one-to-one payload, where the message contains the dossier and file ids, and the file extensions that have + to be used to construct the absolute file paths for the target and response files, that have to be acquired from the + storage.""" + + dossier_id: str + file_id: str + + target_file_extension: str + response_file_extension: str + + class QueueMessagePayloadParser: - def __init__(self, file_extension_parser, allowed_processing_args=("operation",)): + def __init__(self, file_extension_parser, allowed_processing_parameters): self.parse_file_extensions = file_extension_parser - self.allowed_args = allowed_processing_args + self.allowed_processing_params = allowed_processing_parameters def __call__(self, payload: dict) -> QueueMessagePayload: - """Translate the queue message payload to the internal QueueMessagePayload object.""" - return self._parse_queue_message_payload(payload) + if maybe_legacy_payload(payload): + logger.debug("Legacy payload detected.") + return self._parse_legacy_queue_message_payload(payload) + else: + return self._parse_queue_message_payload(payload) def _parse_queue_message_payload(self, payload: dict) -> QueueMessagePayload: + target_file_path, response_file_path = itemgetter("targetFilePath", "responseFilePath")(payload) + + target_file_type, target_compression_type, response_file_type, response_compression_type = chain.from_iterable( + map(self.parse_file_extensions, [target_file_path, response_file_path]) + ) + + x_tenant_id = payload.get("X-TENANT-ID") + + processing_kwargs = project(payload, self.allowed_processing_params) + + return QueueMessagePayload( + target_file_path=target_file_path, + response_file_path=response_file_path, + target_file_type=target_file_type, + target_compression_type=target_compression_type, + response_file_type=response_file_type, + response_compression_type=response_compression_type, + x_tenant_id=x_tenant_id, + processing_kwargs=processing_kwargs, + ) + + def _parse_legacy_queue_message_payload(self, payload: dict) -> LegacyQueueMessagePayload: dossier_id, file_id, target_file_extension, response_file_extension = itemgetter( "dossierId", "fileId", "targetFileExtension", "responseFileExtension" )(payload) - x_tenant_id = payload.get("X-TENANT-ID") + + target_file_path = f"{dossier_id}/{file_id}.{target_file_extension}" + response_file_path = f"{dossier_id}/{file_id}.{response_file_extension}" target_file_type, target_compression_type, response_file_type, response_compression_type = chain.from_iterable( map(self.parse_file_extensions, [target_file_extension, response_file_extension]) ) - target_file_name = f"{dossier_id}/{file_id}.{target_file_extension}" - response_file_name = f"{dossier_id}/{file_id}.{response_file_extension}" + x_tenant_id = payload.get("X-TENANT-ID") - processing_kwargs = project(payload, self.allowed_args) + processing_kwargs = project(payload, self.allowed_processing_params) - return QueueMessagePayload( + return LegacyQueueMessagePayload( dossier_id=dossier_id, file_id=file_id, x_tenant_id=x_tenant_id, @@ -63,36 +101,59 @@ class QueueMessagePayloadParser: target_compression_type=target_compression_type, response_file_type=response_file_type, response_compression_type=response_compression_type, - target_file_name=target_file_name, - response_file_name=response_file_name, + target_file_path=target_file_path, + response_file_path=response_file_path, processing_kwargs=processing_kwargs, ) +def maybe_legacy_payload(payload: dict) -> bool: + return {"dossierId", "fileId", "targetFileExtension", "responseFileExtension"}.issubset(payload.keys()) + + def get_queue_message_payload_parser(config: Config) -> QueueMessagePayloadParser: file_extension_parser = make_file_extension_parser(config.allowed_file_types, config.allowed_compression_types) - return QueueMessagePayloadParser(file_extension_parser) + return QueueMessagePayloadParser(file_extension_parser, config.allowed_processing_parameters) -class QueueMessagePayloadFormatter: - @staticmethod - def format_service_processing_result_for_storage( - queue_message_payload: QueueMessagePayload, service_processing_result: Sized - ) -> dict: - """Format the results of a processing function with the QueueMessagePayload for the storage upload.""" - return { - "dossierId": queue_message_payload.dossier_id, - "fileId": queue_message_payload.file_id, - "targetFileExtension": queue_message_payload.target_file_extension, - "responseFileExtension": queue_message_payload.response_file_extension, - "data": service_processing_result, - } - - @staticmethod - def format_to_queue_message_response_body(queue_message_payload: QueueMessagePayload) -> dict: - """Format QueueMessagePayload for the AMPQ response after processing.""" - return {"dossierId": queue_message_payload.dossier_id, "fileId": queue_message_payload.file_id} +@singledispatch +def format_service_processing_result_for_storage(payload: QueueMessagePayload, result: Sized) -> dict: + raise NotImplementedError("Unsupported payload type") -def get_queue_message_payload_formatter() -> QueueMessagePayloadFormatter: - return QueueMessagePayloadFormatter() +@format_service_processing_result_for_storage.register(LegacyQueueMessagePayload) +def _(payload: LegacyQueueMessagePayload, result: Sized) -> dict: + return { + "dossierId": payload.dossier_id, + "fileId": payload.file_id, + "targetFileExtension": payload.target_file_extension, + "responseFileExtension": payload.response_file_extension, + "data": result, + } + + +@format_service_processing_result_for_storage.register(QueueMessagePayload) +def _(payload: QueueMessagePayload, result: Sized) -> dict: + return { + "targetFilePath": payload.target_file_path, + "responseFilePath": payload.response_file_path, + "data": result, + } + + +@singledispatch +def format_to_queue_message_response_body(queue_message_payload: QueueMessagePayload) -> dict: + raise NotImplementedError("Unsupported payload type") + + +@format_to_queue_message_response_body.register(LegacyQueueMessagePayload) +def _(queue_message_payload: LegacyQueueMessagePayload) -> dict: + return {"dossierId": queue_message_payload.dossier_id, "fileId": queue_message_payload.file_id} + + +@format_to_queue_message_response_body.register(QueueMessagePayload) +def _(queue_message_payload: QueueMessagePayload) -> dict: + return { + "targetFilePath": queue_message_payload.target_file_path, + "responseFilePath": queue_message_payload.response_file_path, + } diff --git a/pyinfra/payload_processing/processor.py b/pyinfra/payload_processing/processor.py index e4359f9..5f98d22 100644 --- a/pyinfra/payload_processing/processor.py +++ b/pyinfra/payload_processing/processor.py @@ -8,8 +8,9 @@ from pyinfra.payload_processing.monitor import get_monitor_from_config from pyinfra.payload_processing.payload import ( QueueMessagePayloadParser, get_queue_message_payload_parser, - QueueMessagePayloadFormatter, - get_queue_message_payload_formatter, + format_service_processing_result_for_storage, + format_to_queue_message_response_body, + QueueMessagePayload, ) from pyinfra.storage.storage import make_downloader, make_uploader from pyinfra.storage.storage_info import ( @@ -29,7 +30,6 @@ class PayloadProcessor: default_storage_info: StorageInfo, get_storage_info_from_tenant_id, payload_parser: QueueMessagePayloadParser, - payload_formatter: QueueMessagePayloadFormatter, data_processor: Callable, ): """Wraps an analysis function specified by a service (e.g. NER service) in pre- and post-processing steps. @@ -39,14 +39,11 @@ class PayloadProcessor: x_tenant_id is not provided in the queue payload. get_storage_info_from_tenant_id: Callable to acquire storage info from a given tenant id. payload_parser: Parser that translates the queue message payload to the required QueueMessagePayload object - payload_formatter: Formatter for the storage upload result and the queue message response body data_processor: The analysis function to be called with the downloaded file NOTE: The result of the analysis function has to be an instance of `Sized`, e.g. a dict or a list to be able to upload it and to be able to monitor the processing time. """ self.parse_payload = payload_parser - self.format_result_for_storage = payload_formatter.format_service_processing_result_for_storage - self.format_to_queue_message_response_body = payload_formatter.format_to_queue_message_response_body self.process_data = data_processor self.get_storage_info_from_tenant_id = get_storage_info_from_tenant_id @@ -71,8 +68,11 @@ class PayloadProcessor: return self._process(queue_message_payload) def _process(self, queue_message_payload: dict) -> dict: - payload = self.parse_payload(queue_message_payload) - logger.info(f"Processing {asdict(payload)} ...") + logger.info(f"Processing Payload ...") + + payload: QueueMessagePayload = self.parse_payload(queue_message_payload) + + logger.debug(f"Payload: {asdict(payload)} ...") storage_info = self._get_storage_info(payload.x_tenant_id) storage = get_storage_from_storage_info(storage_info) @@ -84,15 +84,14 @@ class PayloadProcessor: upload_processing_result = make_uploader( storage, bucket, payload.response_file_type, payload.response_compression_type ) - format_result_for_storage = partial(self.format_result_for_storage, payload) - data = download_file_to_process(payload.target_file_name) + data = download_file_to_process(payload.target_file_path) result: List[dict] = self.process_data(data, **payload.processing_kwargs) - formatted_result = format_result_for_storage(result) + formatted_result = format_service_processing_result_for_storage(payload, result) - upload_processing_result(payload.response_file_name, formatted_result) + upload_processing_result(payload.response_file_path, formatted_result) - return self.format_to_queue_message_response_body(payload) + return format_to_queue_message_response_body(payload) def _get_storage_info(self, x_tenant_id=None): if x_tenant_id: @@ -118,7 +117,6 @@ def make_payload_processor(data_processor: Callable, config: Config = None) -> P ) monitor = get_monitor_from_config(config) payload_parser: QueueMessagePayloadParser = get_queue_message_payload_parser(config) - payload_formatter: QueueMessagePayloadFormatter = get_queue_message_payload_formatter() data_processor = monitor(data_processor) @@ -126,6 +124,5 @@ def make_payload_processor(data_processor: Callable, config: Config = None) -> P default_storage_info, get_storage_info_from_tenant_id, payload_parser, - payload_formatter, data_processor, ) diff --git a/tests/payload_parsing_test.py b/tests/payload_parsing_test.py index 303f301..a66fb61 100644 --- a/tests/payload_parsing_test.py +++ b/tests/payload_parsing_test.py @@ -1,15 +1,15 @@ import pytest from pyinfra.payload_processing.payload import ( - QueueMessagePayload, QueueMessagePayloadParser, + LegacyQueueMessagePayload, ) from pyinfra.utils.file_extension_parsing import make_file_extension_parser @pytest.fixture def expected_parsed_payload(x_tenant_id): - return QueueMessagePayload( + return LegacyQueueMessagePayload( dossier_id="test", file_id="test", x_tenant_id=x_tenant_id, @@ -19,8 +19,8 @@ def expected_parsed_payload(x_tenant_id): target_compression_type="gz", response_file_type="json", response_compression_type="gz", - target_file_name="test/test.json.gz", - response_file_name="test/test.json.gz", + target_file_path="test/test.json.gz", + response_file_path="test/test.json.gz", processing_kwargs={}, ) @@ -32,7 +32,7 @@ def file_extension_parser(allowed_file_types, allowed_compression_types): @pytest.fixture def payload_parser(file_extension_parser): - return QueueMessagePayloadParser(file_extension_parser) + return QueueMessagePayloadParser(file_extension_parser, allowed_processing_parameters=["operation"]) @pytest.mark.parametrize("allowed_file_types,allowed_compression_types", [(["json", "pdf"], ["gz"])]) diff --git a/tests/payload_processor_test.py b/tests/payload_processor_test.py index 0c48ac4..b7a8b23 100644 --- a/tests/payload_processor_test.py +++ b/tests/payload_processor_test.py @@ -66,7 +66,9 @@ class TestPayloadProcessor: with pytest.raises(Exception): payload_processor(payload) - def test_prometheus_endpoint_is_available(self, test_storage_config, monitoring_enabled, storage_backend, x_tenant_id): + def test_prometheus_endpoint_is_available( + self, test_storage_config, monitoring_enabled, storage_backend, x_tenant_id + ): if monitoring_enabled: resp = requests.get( f"http://{test_storage_config.prometheus_host}:{test_storage_config.prometheus_port}/prometheus"