Refactor storage provider & payload parser
Applies strategy pattern to payload parsing logic to improve maintainability and testability. Renames storage manager to storage provider.
This commit is contained in:
parent
294688ea66
commit
e580a66347
@ -1,11 +1,10 @@
|
||||
from dataclasses import dataclass
|
||||
from functools import singledispatch
|
||||
from funcy import project
|
||||
from functools import singledispatch, partial
|
||||
from funcy import project, complement
|
||||
from itertools import chain
|
||||
from operator import itemgetter
|
||||
from typing import Union, Sized
|
||||
from typing import Union, Sized, Callable, List
|
||||
|
||||
from kn_utils.logging import logger
|
||||
from pyinfra.config import Config
|
||||
from pyinfra.utils.file_extension_parsing import make_file_extension_parser
|
||||
|
||||
@ -42,78 +41,105 @@ class LegacyQueueMessagePayload(QueueMessagePayload):
|
||||
|
||||
|
||||
class QueueMessagePayloadParser:
|
||||
def __init__(self, file_extension_parser, allowed_processing_parameters):
|
||||
self.parse_file_extensions = file_extension_parser
|
||||
self.allowed_processing_params = allowed_processing_parameters
|
||||
def __init__(self, payload_matcher2parse_strategy: dict):
|
||||
self.payload_matcher2parse_strategy = payload_matcher2parse_strategy
|
||||
|
||||
def __call__(self, payload: dict) -> QueueMessagePayload:
|
||||
if maybe_legacy_payload(payload):
|
||||
logger.debug("Legacy payload detected.")
|
||||
return self._parse_legacy_queue_message_payload(payload)
|
||||
else:
|
||||
return self._parse_queue_message_payload(payload)
|
||||
|
||||
def _parse_queue_message_payload(self, payload: dict) -> QueueMessagePayload:
|
||||
target_file_path, response_file_path = itemgetter("targetFilePath", "responseFilePath")(payload)
|
||||
|
||||
target_file_type, target_compression_type, response_file_type, response_compression_type = chain.from_iterable(
|
||||
map(self.parse_file_extensions, [target_file_path, response_file_path])
|
||||
)
|
||||
|
||||
x_tenant_id = payload.get("X-TENANT-ID")
|
||||
|
||||
processing_kwargs = project(payload, self.allowed_processing_params)
|
||||
|
||||
return QueueMessagePayload(
|
||||
target_file_path=target_file_path,
|
||||
response_file_path=response_file_path,
|
||||
target_file_type=target_file_type,
|
||||
target_compression_type=target_compression_type,
|
||||
response_file_type=response_file_type,
|
||||
response_compression_type=response_compression_type,
|
||||
x_tenant_id=x_tenant_id,
|
||||
processing_kwargs=processing_kwargs,
|
||||
)
|
||||
|
||||
def _parse_legacy_queue_message_payload(self, payload: dict) -> LegacyQueueMessagePayload:
|
||||
dossier_id, file_id, target_file_extension, response_file_extension = itemgetter(
|
||||
"dossierId", "fileId", "targetFileExtension", "responseFileExtension"
|
||||
)(payload)
|
||||
|
||||
target_file_path = f"{dossier_id}/{file_id}.{target_file_extension}"
|
||||
response_file_path = f"{dossier_id}/{file_id}.{response_file_extension}"
|
||||
|
||||
target_file_type, target_compression_type, response_file_type, response_compression_type = chain.from_iterable(
|
||||
map(self.parse_file_extensions, [target_file_extension, response_file_extension])
|
||||
)
|
||||
|
||||
x_tenant_id = payload.get("X-TENANT-ID")
|
||||
|
||||
processing_kwargs = project(payload, self.allowed_processing_params)
|
||||
|
||||
return LegacyQueueMessagePayload(
|
||||
dossier_id=dossier_id,
|
||||
file_id=file_id,
|
||||
x_tenant_id=x_tenant_id,
|
||||
target_file_extension=target_file_extension,
|
||||
response_file_extension=response_file_extension,
|
||||
target_file_type=target_file_type,
|
||||
target_compression_type=target_compression_type,
|
||||
response_file_type=response_file_type,
|
||||
response_compression_type=response_compression_type,
|
||||
target_file_path=target_file_path,
|
||||
response_file_path=response_file_path,
|
||||
processing_kwargs=processing_kwargs,
|
||||
)
|
||||
|
||||
|
||||
def maybe_legacy_payload(payload: dict) -> bool:
|
||||
return {"dossierId", "fileId", "targetFileExtension", "responseFileExtension"}.issubset(payload.keys())
|
||||
for payload_matcher, parse_strategy in self.payload_matcher2parse_strategy.items():
|
||||
if payload_matcher(payload):
|
||||
return parse_strategy(payload)
|
||||
|
||||
|
||||
def get_queue_message_payload_parser(config: Config) -> QueueMessagePayloadParser:
|
||||
file_extension_parser = make_file_extension_parser(config.allowed_file_types, config.allowed_compression_types)
|
||||
return QueueMessagePayloadParser(file_extension_parser, config.allowed_processing_parameters)
|
||||
|
||||
payload_matcher2parse_strategy = get_payload_matcher2parse_strategy(
|
||||
file_extension_parser, config.allowed_processing_parameters
|
||||
)
|
||||
|
||||
return QueueMessagePayloadParser(payload_matcher2parse_strategy)
|
||||
|
||||
|
||||
def get_payload_matcher2parse_strategy(parse_file_extensions: Callable, allowed_processing_parameters: List[str]):
    """Build the mapping from payload matchers to their parse strategies.

    Each key is a predicate taking the raw payload dict; each value is the
    parse function for payloads accepted by that predicate, with the file
    extension parser and the allowed processing parameters already bound.

    Args:
        parse_file_extensions: Callable handed to both parse strategies as
            their ``parse_file_extensions`` argument.
        allowed_processing_parameters: Handed to both parse strategies as
            their ``allowed_processing_parameters`` argument.

    Returns:
        A dict with two entries, inserted in this order: ``is_legacy_payload``
        mapped to the legacy parser, then ``complement(is_legacy_payload)``
        mapped to the regular parser.
    """
    # Both strategies receive exactly the same bound keyword arguments.
    bound_kwargs = {
        "parse_file_extensions": parse_file_extensions,
        "allowed_processing_parameters": allowed_processing_parameters,
    }

    legacy_strategy = partial(parse_legacy_queue_message_payload, **bound_kwargs)
    regular_strategy = partial(parse_queue_message_payload, **bound_kwargs)

    # Legacy matcher goes first so insertion (and thus iteration) order is kept.
    return {
        is_legacy_payload: legacy_strategy,
        complement(is_legacy_payload): regular_strategy,
    }
|
||||
|
||||
|
||||
def is_legacy_payload(payload: dict) -> bool:
    """Tell whether a queue message payload uses the legacy format.

    A payload counts as legacy when it contains all of the keys
    ``dossierId``, ``fileId``, ``targetFileExtension`` and
    ``responseFileExtension``; additional keys are allowed.

    Args:
        payload: The raw queue message payload.

    Returns:
        True if every legacy key is present in the payload, False otherwise.
    """
    legacy_keys = {"dossierId", "fileId", "targetFileExtension", "responseFileExtension"}
    # Iterating a mapping yields its keys, so no explicit .keys() view is needed.
    return legacy_keys.issubset(payload)
|
||||
|
||||
|
||||
def parse_queue_message_payload(
    payload: dict,
    parse_file_extensions: Callable,
    allowed_processing_parameters: List[str],
) -> QueueMessagePayload:
    """Translate a path-based queue message payload into a QueueMessagePayload.

    Args:
        payload: Raw queue message; must contain ``targetFilePath`` and
            ``responseFilePath``. ``X-TENANT-ID`` is optional.
        parse_file_extensions: Callable applied first to the target path and
            then to the response path. The values it yields for both paths
            are unpacked, in order, into target file type, target compression
            type, response file type and response compression type.
        allowed_processing_parameters: Keys of ``payload`` that are passed on
            as ``processing_kwargs``.

    Returns:
        The populated QueueMessagePayload.

    Raises:
        KeyError: If ``targetFilePath`` or ``responseFilePath`` is missing.
        ValueError: If the parsed extension values do not unpack into exactly
            four items.
    """
    target_file_path = payload["targetFilePath"]
    response_file_path = payload["responseFilePath"]

    # Lazily flatten the parser results for both paths into one stream.
    extension_parts = (
        part
        for file_path in (target_file_path, response_file_path)
        for part in parse_file_extensions(file_path)
    )
    (
        target_file_type,
        target_compression_type,
        response_file_type,
        response_compression_type,
    ) = extension_parts

    tenant_id = payload.get("X-TENANT-ID")
    forwarded_kwargs = project(payload, allowed_processing_parameters)

    return QueueMessagePayload(
        target_file_path=target_file_path,
        response_file_path=response_file_path,
        target_file_type=target_file_type,
        target_compression_type=target_compression_type,
        response_file_type=response_file_type,
        response_compression_type=response_compression_type,
        x_tenant_id=tenant_id,
        processing_kwargs=forwarded_kwargs,
    )
|
||||
|
||||
|
||||
def parse_legacy_queue_message_payload(
    payload: dict,
    parse_file_extensions: Callable,
    allowed_processing_parameters: List[str],
) -> LegacyQueueMessagePayload:
    """Translate a legacy queue message payload into a LegacyQueueMessagePayload.

    The target and response file paths are not part of a legacy payload; they
    are assembled as ``<dossierId>/<fileId>.<extension>``.

    Args:
        payload: Raw queue message; must contain ``dossierId``, ``fileId``,
            ``targetFileExtension`` and ``responseFileExtension``.
            ``X-TENANT-ID`` is optional.
        parse_file_extensions: Callable applied first to the target extension
            and then to the response extension. The values it yields for both
            are unpacked, in order, into target file type, target compression
            type, response file type and response compression type.
        allowed_processing_parameters: Keys of ``payload`` that are passed on
            as ``processing_kwargs``.

    Returns:
        The populated LegacyQueueMessagePayload.

    Raises:
        KeyError: If one of the required legacy keys is missing.
        ValueError: If the parsed extension values do not unpack into exactly
            four items.
    """
    dossier_id = payload["dossierId"]
    file_id = payload["fileId"]
    target_file_extension = payload["targetFileExtension"]
    response_file_extension = payload["responseFileExtension"]

    target_file_path = f"{dossier_id}/{file_id}.{target_file_extension}"
    response_file_path = f"{dossier_id}/{file_id}.{response_file_extension}"

    # Lazily flatten the parser results for both extensions into one stream.
    extension_parts = (
        part
        for extension in (target_file_extension, response_file_extension)
        for part in parse_file_extensions(extension)
    )
    (
        target_file_type,
        target_compression_type,
        response_file_type,
        response_compression_type,
    ) = extension_parts

    tenant_id = payload.get("X-TENANT-ID")
    forwarded_kwargs = project(payload, allowed_processing_parameters)

    return LegacyQueueMessagePayload(
        dossier_id=dossier_id,
        file_id=file_id,
        x_tenant_id=tenant_id,
        target_file_extension=target_file_extension,
        response_file_extension=response_file_extension,
        target_file_type=target_file_type,
        target_compression_type=target_compression_type,
        response_file_type=response_file_type,
        response_compression_type=response_compression_type,
        target_file_path=target_file_path,
        response_file_path=response_file_path,
        processing_kwargs=forwarded_kwargs,
    )
|
||||
|
||||
|
||||
@singledispatch
|
||||
|
||||
@ -12,7 +12,7 @@ from pyinfra.payload_processing.payload import (
|
||||
QueueMessagePayload,
|
||||
)
|
||||
from pyinfra.storage.storage import make_downloader, make_uploader
|
||||
from pyinfra.storage.storage_manager import StorageManager
|
||||
from pyinfra.storage.storage_provider import StorageProvider
|
||||
|
||||
logger = getLogger()
|
||||
logger.setLevel(get_config().logging_level_root)
|
||||
@ -21,21 +21,21 @@ logger.setLevel(get_config().logging_level_root)
|
||||
class PayloadProcessor:
|
||||
def __init__(
|
||||
self,
|
||||
storage_manager: StorageManager,
|
||||
storage_provider: StorageProvider,
|
||||
payload_parser: QueueMessagePayloadParser,
|
||||
data_processor: Callable,
|
||||
):
|
||||
"""Wraps an analysis function specified by a service (e.g. NER service) in pre- and post-processing steps.
|
||||
|
||||
Args:
|
||||
storage_manager: Storage manager that connects to the storage, using the tenant id if provided
|
||||
storage_provider: Storage provider that connects to the storage, using the tenant id if provided
|
||||
payload_parser: Parser that translates the queue message payload to the required QueueMessagePayload object
|
||||
data_processor: The analysis function to be called with the downloaded file
|
||||
NOTE: The result of the analysis function has to be an instance of `Sized`, e.g. a dict or a list to be
|
||||
able to upload it and to be able to monitor the processing time.
|
||||
"""
|
||||
self.parse_payload = payload_parser
|
||||
self.connect_storage = storage_manager
|
||||
self.provide_storage = storage_provider
|
||||
self.process_data = data_processor
|
||||
|
||||
def __call__(self, queue_message_payload: dict) -> dict:
|
||||
@ -65,13 +65,13 @@ class PayloadProcessor:
|
||||
logger.info(f"Processing {payload.__class__.__name__} ...")
|
||||
logger.debug(f"Payload contents: {asdict(payload)} ...")
|
||||
|
||||
storage, storage_info = self.connect_storage(payload.x_tenant_id)
|
||||
storage, storage_info = self.provide_storage(payload.x_tenant_id)
|
||||
|
||||
download_file_to_process = make_downloader(
|
||||
storage, storage_info.bucket, payload.target_file_type, payload.target_compression_type
|
||||
storage, storage_info.bucket_name, payload.target_file_type, payload.target_compression_type
|
||||
)
|
||||
upload_processing_result = make_uploader(
|
||||
storage, storage_info.bucket, payload.response_file_type, payload.response_compression_type
|
||||
storage, storage_info.bucket_name, payload.response_file_type, payload.response_compression_type
|
||||
)
|
||||
|
||||
data = download_file_to_process(payload.target_file_path)
|
||||
@ -87,14 +87,14 @@ def make_payload_processor(data_processor: Callable, config: Config = None) -> P
|
||||
"""Creates a payload processor."""
|
||||
config = config or get_config()
|
||||
|
||||
storage_manager = StorageManager(config)
|
||||
storage_provider = StorageProvider(config)
|
||||
monitor = get_monitor_from_config(config)
|
||||
payload_parser: QueueMessagePayloadParser = get_queue_message_payload_parser(config)
|
||||
|
||||
data_processor = monitor(data_processor)
|
||||
|
||||
return PayloadProcessor(
|
||||
storage_manager,
|
||||
storage_provider,
|
||||
payload_parser,
|
||||
data_processor,
|
||||
)
|
||||
|
||||
@ -4,11 +4,12 @@ from kn_utils.logging import logger
|
||||
from typing import Tuple
|
||||
|
||||
from pyinfra.config import Config
|
||||
from pyinfra.storage.storage_info import get_storage_info_from_config, get_storage_info_from_endpoint, StorageInfo
|
||||
from pyinfra.storage.storage_info import get_storage_info_from_config, get_storage_info_from_endpoint, StorageInfo, \
|
||||
get_storage_from_storage_info
|
||||
from pyinfra.storage.storages.interface import Storage
|
||||
|
||||
|
||||
class StorageManager:
|
||||
class StorageProvider:
|
||||
def __init__(self, config: Config):
|
||||
self.config = config
|
||||
self.default_storage_info: StorageInfo = get_storage_info_from_config(config)
|
||||
@ -25,7 +26,7 @@ class StorageManager:
|
||||
@lru_cache(maxsize=32)
|
||||
def connect(self, x_tenant_id=None) -> Tuple[Storage, StorageInfo]:
|
||||
storage_info = self._get_storage_info(x_tenant_id)
|
||||
storage_connection = storage_info.get_storage()
|
||||
storage_connection = get_storage_from_storage_info(storage_info)
|
||||
return storage_connection, storage_info
|
||||
|
||||
def _get_storage_info(self, x_tenant_id=None):
|
||||
Loading…
x
Reference in New Issue
Block a user