Refactor storage provider & payload parser

Applies strategy pattern to payload parsing logic to improve
maintainability and testability.
Renames storage manager to storage provider.
This commit is contained in:
Julius Unverfehrt 2023-08-21 15:11:08 +02:00
parent 294688ea66
commit e580a66347
3 changed files with 109 additions and 82 deletions

View File

@ -1,11 +1,10 @@
from dataclasses import dataclass
from functools import singledispatch
from funcy import project
from functools import singledispatch, partial
from funcy import project, complement
from itertools import chain
from operator import itemgetter
from typing import Union, Sized
from typing import Union, Sized, Callable, List
from kn_utils.logging import logger
from pyinfra.config import Config
from pyinfra.utils.file_extension_parsing import make_file_extension_parser
@ -42,78 +41,105 @@ class LegacyQueueMessagePayload(QueueMessagePayload):
class QueueMessagePayloadParser:
def __init__(self, file_extension_parser, allowed_processing_parameters):
self.parse_file_extensions = file_extension_parser
self.allowed_processing_params = allowed_processing_parameters
def __init__(self, payload_matcher2parse_strategy: dict):
self.payload_matcher2parse_strategy = payload_matcher2parse_strategy
def __call__(self, payload: dict) -> QueueMessagePayload:
if maybe_legacy_payload(payload):
logger.debug("Legacy payload detected.")
return self._parse_legacy_queue_message_payload(payload)
else:
return self._parse_queue_message_payload(payload)
def _parse_queue_message_payload(self, payload: dict) -> QueueMessagePayload:
target_file_path, response_file_path = itemgetter("targetFilePath", "responseFilePath")(payload)
target_file_type, target_compression_type, response_file_type, response_compression_type = chain.from_iterable(
map(self.parse_file_extensions, [target_file_path, response_file_path])
)
x_tenant_id = payload.get("X-TENANT-ID")
processing_kwargs = project(payload, self.allowed_processing_params)
return QueueMessagePayload(
target_file_path=target_file_path,
response_file_path=response_file_path,
target_file_type=target_file_type,
target_compression_type=target_compression_type,
response_file_type=response_file_type,
response_compression_type=response_compression_type,
x_tenant_id=x_tenant_id,
processing_kwargs=processing_kwargs,
)
def _parse_legacy_queue_message_payload(self, payload: dict) -> LegacyQueueMessagePayload:
dossier_id, file_id, target_file_extension, response_file_extension = itemgetter(
"dossierId", "fileId", "targetFileExtension", "responseFileExtension"
)(payload)
target_file_path = f"{dossier_id}/{file_id}.{target_file_extension}"
response_file_path = f"{dossier_id}/{file_id}.{response_file_extension}"
target_file_type, target_compression_type, response_file_type, response_compression_type = chain.from_iterable(
map(self.parse_file_extensions, [target_file_extension, response_file_extension])
)
x_tenant_id = payload.get("X-TENANT-ID")
processing_kwargs = project(payload, self.allowed_processing_params)
return LegacyQueueMessagePayload(
dossier_id=dossier_id,
file_id=file_id,
x_tenant_id=x_tenant_id,
target_file_extension=target_file_extension,
response_file_extension=response_file_extension,
target_file_type=target_file_type,
target_compression_type=target_compression_type,
response_file_type=response_file_type,
response_compression_type=response_compression_type,
target_file_path=target_file_path,
response_file_path=response_file_path,
processing_kwargs=processing_kwargs,
)
def maybe_legacy_payload(payload: dict) -> bool:
return {"dossierId", "fileId", "targetFileExtension", "responseFileExtension"}.issubset(payload.keys())
for payload_matcher, parse_strategy in self.payload_matcher2parse_strategy.items():
if payload_matcher(payload):
return parse_strategy(payload)
def get_queue_message_payload_parser(config: Config) -> QueueMessagePayloadParser:
file_extension_parser = make_file_extension_parser(config.allowed_file_types, config.allowed_compression_types)
return QueueMessagePayloadParser(file_extension_parser, config.allowed_processing_parameters)
payload_matcher2parse_strategy = get_payload_matcher2parse_strategy(
file_extension_parser, config.allowed_processing_parameters
)
return QueueMessagePayloadParser(payload_matcher2parse_strategy)
def get_payload_matcher2parse_strategy(parse_file_extensions: Callable, allowed_processing_parameters: List[str]):
    """Build the matcher -> parse-strategy mapping used by QueueMessagePayloadParser.

    Both strategies are pre-bound with the shared file-extension parser and the
    whitelist of processing parameters, so each entry only needs the payload.

    Args:
        parse_file_extensions: Callable that maps a path/extension to file-type info.
        allowed_processing_parameters: Payload keys forwarded as processing kwargs.

    Returns:
        Dict mapping a payload predicate to the parse function to run when it matches.
    """
    shared_kwargs = dict(
        parse_file_extensions=parse_file_extensions,
        allowed_processing_parameters=allowed_processing_parameters,
    )
    legacy_strategy = partial(parse_legacy_queue_message_payload, **shared_kwargs)
    current_strategy = partial(parse_queue_message_payload, **shared_kwargs)
    return {
        is_legacy_payload: legacy_strategy,
        complement(is_legacy_payload): current_strategy,
    }
def is_legacy_payload(payload: dict) -> bool:
    """Return True when the payload carries every key of the legacy message format."""
    legacy_keys = ("dossierId", "fileId", "targetFileExtension", "responseFileExtension")
    return all(key in payload for key in legacy_keys)
def parse_queue_message_payload(
    payload: dict,
    parse_file_extensions: Callable,
    allowed_processing_parameters: List[str],
) -> QueueMessagePayload:
    """Translate a non-legacy queue message payload into a QueueMessagePayload.

    Args:
        payload: Raw queue message; must contain "targetFilePath" and
            "responseFilePath" (a KeyError is raised otherwise).
        parse_file_extensions: Callable that yields (file_type, compression_type)
            for a given path.  # assumes a 2-item result per call — matches the
            # original 4-way unpack over two calls; confirm against the parser.
        allowed_processing_parameters: Payload keys copied into processing_kwargs.
    """
    target_file_path = payload["targetFilePath"]
    response_file_path = payload["responseFilePath"]
    target_file_type, target_compression_type = parse_file_extensions(target_file_path)
    response_file_type, response_compression_type = parse_file_extensions(response_file_path)
    return QueueMessagePayload(
        target_file_path=target_file_path,
        response_file_path=response_file_path,
        target_file_type=target_file_type,
        target_compression_type=target_compression_type,
        response_file_type=response_file_type,
        response_compression_type=response_compression_type,
        x_tenant_id=payload.get("X-TENANT-ID"),
        processing_kwargs=project(payload, allowed_processing_parameters),
    )
def parse_legacy_queue_message_payload(
    payload: dict,
    parse_file_extensions: Callable,
    allowed_processing_parameters: List[str],
) -> LegacyQueueMessagePayload:
    """Translate a legacy queue message payload into a LegacyQueueMessagePayload.

    Legacy messages identify files via dossier/file ids plus extensions; the
    storage paths are reconstructed as "<dossierId>/<fileId>.<extension>".

    Args:
        payload: Raw queue message; must contain "dossierId", "fileId",
            "targetFileExtension" and "responseFileExtension" (KeyError otherwise).
        parse_file_extensions: Callable that yields (file_type, compression_type)
            for a given extension.  # assumes a 2-item result per call — matches
            # the original 4-way unpack over two calls; confirm against the parser.
        allowed_processing_parameters: Payload keys copied into processing_kwargs.
    """
    dossier_id = payload["dossierId"]
    file_id = payload["fileId"]
    target_file_extension = payload["targetFileExtension"]
    response_file_extension = payload["responseFileExtension"]
    target_file_type, target_compression_type = parse_file_extensions(target_file_extension)
    response_file_type, response_compression_type = parse_file_extensions(response_file_extension)
    return LegacyQueueMessagePayload(
        dossier_id=dossier_id,
        file_id=file_id,
        x_tenant_id=payload.get("X-TENANT-ID"),
        target_file_extension=target_file_extension,
        response_file_extension=response_file_extension,
        target_file_type=target_file_type,
        target_compression_type=target_compression_type,
        response_file_type=response_file_type,
        response_compression_type=response_compression_type,
        target_file_path=f"{dossier_id}/{file_id}.{target_file_extension}",
        response_file_path=f"{dossier_id}/{file_id}.{response_file_extension}",
        processing_kwargs=project(payload, allowed_processing_parameters),
    )
@singledispatch

View File

@ -12,7 +12,7 @@ from pyinfra.payload_processing.payload import (
QueueMessagePayload,
)
from pyinfra.storage.storage import make_downloader, make_uploader
from pyinfra.storage.storage_manager import StorageManager
from pyinfra.storage.storage_provider import StorageProvider
logger = getLogger()
logger.setLevel(get_config().logging_level_root)
@ -21,21 +21,21 @@ logger.setLevel(get_config().logging_level_root)
class PayloadProcessor:
def __init__(
self,
storage_manager: StorageManager,
storage_provider: StorageProvider,
payload_parser: QueueMessagePayloadParser,
data_processor: Callable,
):
"""Wraps an analysis function specified by a service (e.g. NER service) in pre- and post-processing steps.
Args:
storage_manager: Storage manager that connects to the storage, using the tenant id if provided
storage_provider: Storage provider that connects to the storage, using the tenant id if provided
payload_parser: Parser that translates the queue message payload to the required QueueMessagePayload object
data_processor: The analysis function to be called with the downloaded file
NOTE: The result of the analysis function has to be an instance of `Sized`, e.g. a dict or a list to be
able to upload it and to be able to monitor the processing time.
"""
self.parse_payload = payload_parser
self.connect_storage = storage_manager
self.provide_storage = storage_provider
self.process_data = data_processor
def __call__(self, queue_message_payload: dict) -> dict:
@ -65,13 +65,13 @@ class PayloadProcessor:
logger.info(f"Processing {payload.__class__.__name__} ...")
logger.debug(f"Payload contents: {asdict(payload)} ...")
storage, storage_info = self.connect_storage(payload.x_tenant_id)
storage, storage_info = self.provide_storage(payload.x_tenant_id)
download_file_to_process = make_downloader(
storage, storage_info.bucket, payload.target_file_type, payload.target_compression_type
storage, storage_info.bucket_name, payload.target_file_type, payload.target_compression_type
)
upload_processing_result = make_uploader(
storage, storage_info.bucket, payload.response_file_type, payload.response_compression_type
storage, storage_info.bucket_name, payload.response_file_type, payload.response_compression_type
)
data = download_file_to_process(payload.target_file_path)
@ -87,14 +87,14 @@ def make_payload_processor(data_processor: Callable, config: Config = None) -> P
"""Creates a payload processor."""
config = config or get_config()
storage_manager = StorageManager(config)
storage_provider = StorageProvider(config)
monitor = get_monitor_from_config(config)
payload_parser: QueueMessagePayloadParser = get_queue_message_payload_parser(config)
data_processor = monitor(data_processor)
return PayloadProcessor(
storage_manager,
storage_provider,
payload_parser,
data_processor,
)

View File

@ -4,11 +4,12 @@ from kn_utils.logging import logger
from typing import Tuple
from pyinfra.config import Config
from pyinfra.storage.storage_info import get_storage_info_from_config, get_storage_info_from_endpoint, StorageInfo
from pyinfra.storage.storage_info import get_storage_info_from_config, get_storage_info_from_endpoint, StorageInfo, \
get_storage_from_storage_info
from pyinfra.storage.storages.interface import Storage
class StorageManager:
class StorageProvider:
def __init__(self, config: Config):
self.config = config
self.default_storage_info: StorageInfo = get_storage_info_from_config(config)
@ -25,7 +26,7 @@ class StorageManager:
@lru_cache(maxsize=32)
def connect(self, x_tenant_id=None) -> Tuple[Storage, StorageInfo]:
storage_info = self._get_storage_info(x_tenant_id)
storage_connection = storage_info.get_storage()
storage_connection = get_storage_from_storage_info(storage_info)
return storage_connection, storage_info
def _get_storage_info(self, x_tenant_id=None):