Refactor payload processing logic

Streamlines the payload processor class by encapsulating closely dependent
logic, improving readability and maintainability.
This commit is contained in:
Julius Unverfehrt 2023-08-18 12:39:11 +02:00
parent 48d74b4307
commit ef916ee790
2 changed files with 58 additions and 45 deletions

View File

@ -1,6 +1,5 @@
import logging
from dataclasses import asdict
from functools import partial
from typing import Callable, List
from pyinfra.config import get_config, Config
@ -13,12 +12,7 @@ from pyinfra.payload_processing.payload import (
QueueMessagePayload,
)
from pyinfra.storage.storage import make_downloader, make_uploader
from pyinfra.storage.storage_info import (
get_storage_info_from_config,
get_storage_info_from_endpoint,
StorageInfo,
get_storage_from_storage_info,
)
from pyinfra.storage.storage_manager import StorageManager
logger = logging.getLogger()
logger.setLevel(get_config().logging_level_root)
@ -27,28 +21,23 @@ logger.setLevel(get_config().logging_level_root)
class PayloadProcessor:
def __init__(
self,
default_storage_info: StorageInfo,
get_storage_info_from_tenant_id,
storage_manager: StorageManager,
payload_parser: QueueMessagePayloadParser,
data_processor: Callable,
):
"""Wraps an analysis function specified by a service (e.g. NER service) in pre- and post-processing steps.
Args:
default_storage_info: The default storage info used to create the storage connection. This is only used if
x_tenant_id is not provided in the queue payload.
get_storage_info_from_tenant_id: Callable to acquire storage info from a given tenant id.
storage_manager: Storage manager that connects to the storage, using the tenant id if provided
payload_parser: Parser that translates the queue message payload to the required QueueMessagePayload object
data_processor: The analysis function to be called with the downloaded file
NOTE: The result of the analysis function has to be an instance of `Sized`, e.g. a dict or a list to be
able to upload it and to be able to monitor the processing time.
"""
self.parse_payload = payload_parser
self.connect_storage = storage_manager
self.process_data = data_processor
self.get_storage_info_from_tenant_id = get_storage_info_from_tenant_id
self.default_storage_info = default_storage_info
def __call__(self, queue_message_payload: dict) -> dict:
"""Processes a queue message payload.
@ -60,29 +49,29 @@ class PayloadProcessor:
Args:
queue_message_payload: The payload of a queue message. The payload is expected to be a dict with the
following keys: dossierId, fileId, targetFileExtension, responseFileExtension
following keys:
targetFilePath, responseFilePath
OR
dossierId, fileId, targetFileExtension, responseFileExtension
Returns:
The payload for a response queue message. The payload is a dict with the following keys: dossierId, fileId
The payload for a response queue message, containing only the request payload.
"""
return self._process(queue_message_payload)
def _process(self, queue_message_payload: dict) -> dict:
logger.info(f"Processing Payload ...")
payload: QueueMessagePayload = self.parse_payload(queue_message_payload)
logger.debug(f"Payload: {asdict(payload)} ...")
logger.info(f"Processing {payload.__class__.__name__} ...")
logger.debug(f"Payload contents: {asdict(payload)} ...")
storage_info = self._get_storage_info(payload.x_tenant_id)
storage = get_storage_from_storage_info(storage_info)
bucket = storage_info.bucket_name
storage, storage_info = self.connect_storage(payload.x_tenant_id)
download_file_to_process = make_downloader(
storage, bucket, payload.target_file_type, payload.target_compression_type
storage, storage_info.bucket, payload.target_file_type, payload.target_compression_type
)
upload_processing_result = make_uploader(
storage, bucket, payload.response_file_type, payload.response_compression_type
storage, storage_info.bucket, payload.response_file_type, payload.response_compression_type
)
data = download_file_to_process(payload.target_file_path)
@ -93,36 +82,19 @@ class PayloadProcessor:
return format_to_queue_message_response_body(payload)
def _get_storage_info(self, x_tenant_id=None):
if x_tenant_id:
storage_info = self.get_storage_info_from_tenant_id(x_tenant_id)
logger.info(f"Received {storage_info.__class__.__name__} for {x_tenant_id} from endpoint.")
logger.debug(f"{asdict(storage_info)}")
else:
storage_info = self.default_storage_info
logger.info(f"Using local default {storage_info.__class__.__name__} for {x_tenant_id}.")
logger.debug(f"{asdict(storage_info)}")
return storage_info
def make_payload_processor(data_processor: Callable, config: Config = None) -> PayloadProcessor:
"""Produces payload processor for queue manager."""
"""Creates a payload processor."""
config = config or get_config()
default_storage_info: StorageInfo = get_storage_info_from_config(config)
get_storage_info_from_tenant_id = partial(
get_storage_info_from_endpoint,
config.tenant_decryption_public_key,
config.tenant_endpoint,
)
storage_manager = StorageManager(config)
monitor = get_monitor_from_config(config)
payload_parser: QueueMessagePayloadParser = get_queue_message_payload_parser(config)
data_processor = monitor(data_processor)
return PayloadProcessor(
default_storage_info,
get_storage_info_from_tenant_id,
storage_manager,
payload_parser,
data_processor,
)

View File

@ -0,0 +1,41 @@
from dataclasses import asdict
from functools import partial, lru_cache
from typing import Tuple
from pyinfra import logger
from pyinfra.config import Config
from pyinfra.storage.storage_info import get_storage_info_from_config, get_storage_info_from_endpoint, StorageInfo
from pyinfra.storage.storages.interface import Storage
class StorageManager:
    """Resolves and caches storage connections, optionally per tenant.

    Given an optional ``x_tenant_id``, produces a ``(Storage, StorageInfo)``
    pair: tenant-specific info is fetched from the configured endpoint when an
    id is supplied, otherwise the config-derived default info is used.
    """

    def __init__(self, config: Config):
        """Prepare default storage info and the tenant-info resolver.

        Args:
            config: Project config providing the default storage settings,
                the tenant endpoint URL, and the tenant decryption key.
        """
        self.config = config
        # Fallback used whenever no tenant id is supplied in the payload.
        self.default_storage_info: StorageInfo = get_storage_info_from_config(config)
        self.get_storage_info_from_tenant_id = partial(
            get_storage_info_from_endpoint,
            config.tenant_decryption_public_key,
            config.tenant_endpoint,
        )
        # Per-instance cache instead of decorating `connect` with @lru_cache:
        # an lru_cache on the method itself keys on `self` and keeps every
        # StorageManager instance alive for the cache's lifetime (ruff B019).
        self._connect_cached = lru_cache(maxsize=32)(self._connect)

    def __call__(self, *args, **kwargs):
        """Alias for :meth:`connect` so the manager can be used as a callable."""
        return self.connect(*args, **kwargs)

    def connect(self, x_tenant_id=None) -> Tuple[Storage, StorageInfo]:
        """Return a ``(storage, storage_info)`` pair, cached per tenant id.

        Args:
            x_tenant_id: Optional tenant id; falsy values select the default
                storage info derived from the local config.

        Returns:
            Tuple of the connected storage and the storage info it was built from.
        """
        return self._connect_cached(x_tenant_id)

    def _connect(self, x_tenant_id=None) -> Tuple[Storage, StorageInfo]:
        # Uncached worker; wrapped by the per-instance LRU in __init__.
        storage_info = self._get_storage_info(x_tenant_id)
        storage_connection = storage_info.get_storage()
        return storage_connection, storage_info

    def _get_storage_info(self, x_tenant_id=None) -> StorageInfo:
        """Pick tenant-specific storage info when a tenant id is given,
        otherwise fall back to the configured default."""
        if x_tenant_id:
            storage_info = self.get_storage_info_from_tenant_id(x_tenant_id)
            logger.debug(f"Received {storage_info.__class__.__name__} for {x_tenant_id} from endpoint.")
            logger.trace(f"{asdict(storage_info)}")
        else:
            storage_info = self.default_storage_info
            logger.debug(f"Using local default {storage_info.__class__.__name__} for {x_tenant_id}.")
            logger.trace(f"{asdict(storage_info)}")
        return storage_info