update PayloadProcessor

- introduce a storage cache so that every unique
storage connection is created only once
- add support for passing optional processing
kwargs from the queue message (e.g. the operation key)
through to the processing function
This commit is contained in:
Julius Unverfehrt 2023-03-24 10:41:59 +01:00
parent d48e8108fd
commit 97309fe580
6 changed files with 155 additions and 61 deletions

View File

@ -3,6 +3,8 @@ from itertools import chain
from operator import itemgetter
from typing import Union, Sized
from funcy import project
from pyinfra.config import Config
from pyinfra.utils.file_extension_parsing import make_file_extension_parser
@ -24,10 +26,13 @@ class QueueMessagePayload:
target_file_name: str
response_file_name: str
processing_kwargs: dict
class QueueMessagePayloadParser:
def __init__(self, file_extension_parser):
def __init__(self, file_extension_parser, allowed_processing_args=("operation",)):
self.parse_file_extensions = file_extension_parser
self.allowed_args = allowed_processing_args
def __call__(self, payload: dict) -> QueueMessagePayload:
"""Translate the queue message payload to the internal QueueMessagePayload object."""
@ -46,6 +51,8 @@ class QueueMessagePayloadParser:
target_file_name = f"{dossier_id}/{file_id}.{target_file_extension}"
response_file_name = f"{dossier_id}/{file_id}.{response_file_extension}"
processing_kwargs = project(payload, self.allowed_args)
return QueueMessagePayload(
dossier_id=dossier_id,
file_id=file_id,
@ -58,6 +65,7 @@ class QueueMessagePayloadParser:
response_compression_type=response_compression_type,
target_file_name=target_file_name,
response_file_name=response_file_name,
processing_kwargs=processing_kwargs,
)

View File

@ -1,9 +1,7 @@
import logging
from dataclasses import asdict
from functools import partial
from typing import Callable, Union, List
from funcy import compose
from typing import Callable, List
from pyinfra.config import get_config, Config
from pyinfra.payload_processing.monitor import get_monitor_from_config
@ -13,12 +11,12 @@ from pyinfra.payload_processing.payload import (
QueueMessagePayloadFormatter,
get_queue_message_payload_formatter,
)
from pyinfra.storage.storage import make_downloader, make_uploader, get_storage_from_storage_info
from pyinfra.storage.storage import make_downloader, make_uploader
from pyinfra.storage.storage_info import (
AzureStorageInfo,
S3StorageInfo,
get_storage_info_from_config,
get_storage_info_from_endpoint,
StorageInfo,
DefaultStorageCache,
)
logger = logging.getLogger()
@ -28,7 +26,8 @@ logger.setLevel(get_config().logging_level_root)
class PayloadProcessor:
def __init__(
self,
default_storage_info: Union[AzureStorageInfo, S3StorageInfo],
default_storage_info: StorageInfo,
storage_cache: DefaultStorageCache,
get_storage_info_from_tenant_id,
payload_parser: QueueMessagePayloadParser,
payload_formatter: QueueMessagePayloadFormatter,
@ -53,8 +52,7 @@ class PayloadProcessor:
self.get_storage_info_from_tenant_id = get_storage_info_from_tenant_id
self.default_storage_info = default_storage_info
# TODO: use lru-dict
self.storages = {}
self.storages = storage_cache
def __call__(self, queue_message_payload: dict) -> dict:
"""Processes a queue message payload.
@ -79,8 +77,8 @@ class PayloadProcessor:
logger.info(f"Processing {asdict(payload)} ...")
storage_info = self._get_storage_info(payload.x_tenant_id)
storage = self.storages[storage_info]
bucket = storage_info.bucket_name
storage = self._get_storage(storage_info)
download_file_to_process = make_downloader(
storage, bucket, payload.target_file_type, payload.target_compression_type
@ -90,11 +88,11 @@ class PayloadProcessor:
)
format_result_for_storage = partial(self.format_result_for_storage, payload)
processing_pipeline = compose(format_result_for_storage, self.process_data, download_file_to_process)
data = download_file_to_process(payload.target_file_name)
result: List[dict] = self.process_data(data, **payload.processing_kwargs)
formatted_result = format_result_for_storage(result)
result: List[dict] = processing_pipeline(payload.target_file_name)
upload_processing_result(payload.response_file_name, result)
upload_processing_result(payload.response_file_name, formatted_result)
return self.format_to_queue_message_response_body(payload)
@ -103,20 +101,13 @@ class PayloadProcessor:
return self.get_storage_info_from_tenant_id(x_tenant_id)
return self.default_storage_info
def _get_storage(self, storage_info):
    """Return the storage client for *storage_info*, creating and caching it on first use."""
    if storage_info not in self.storages:
        # First request for this backend: build the connection once and keep it.
        self.storages[storage_info] = get_storage_from_storage_info(storage_info)
    return self.storages[storage_info]
def make_payload_processor(data_processor: Callable, config: Union[None, Config] = None) -> PayloadProcessor:
def make_payload_processor(data_processor: Callable, config: Config = None) -> PayloadProcessor:
"""Produces payload processor for queue manager."""
config = config or get_config()
default_storage_info: Union[AzureStorageInfo, S3StorageInfo] = get_storage_info_from_config(config)
storage_cache: DefaultStorageCache = DefaultStorageCache(max_size=10)
default_storage_info: StorageInfo = get_storage_info_from_config(config)
get_storage_info_from_tenant_id = partial(
get_storage_info_from_endpoint,
config.persistence_service_public_key,
@ -130,6 +121,7 @@ def make_payload_processor(data_processor: Callable, config: Union[None, Config]
return PayloadProcessor(
default_storage_info,
storage_cache,
get_storage_info_from_tenant_id,
payload_parser,
payload_formatter,

View File

@ -1,16 +1,11 @@
from functools import lru_cache, partial
from typing import Callable, Union
from typing import Callable
from azure.storage.blob import BlobServiceClient
from funcy import compose
from minio import Minio
from pyinfra.config import Config
from pyinfra.exception import UnknownStorageBackend
from pyinfra.storage.storage_info import AzureStorageInfo, S3StorageInfo, get_storage_info_from_config
from pyinfra.storage.storages.azure import AzureStorage
from pyinfra.storage.storage_info import get_storage_info_from_config, get_storage_from_storage_info
from pyinfra.storage.storages.interface import Storage
from pyinfra.storage.storages.s3 import S3Storage
from pyinfra.utils.compressing import get_decompressor, get_compressor
from pyinfra.utils.encoding import get_decoder, get_encoder
@ -23,23 +18,6 @@ def get_storage_from_config(config: Config) -> Storage:
return storage
def get_storage_from_storage_info(storage_info: Union[AzureStorageInfo, S3StorageInfo]) -> Storage:
    """Build the concrete storage client for the backend described by *storage_info*.

    Raises:
        UnknownStorageBackend: if *storage_info* is neither an Azure nor an S3 info.
    """
    if isinstance(storage_info, S3StorageInfo):
        minio_client = Minio(
            secure=storage_info.secure,
            endpoint=storage_info.endpoint,
            access_key=storage_info.access_key,
            secret_key=storage_info.secret_key,
            region=storage_info.region,
        )
        return S3Storage(minio_client)
    if isinstance(storage_info, AzureStorageInfo):
        blob_client = BlobServiceClient.from_connection_string(conn_str=storage_info.connection_string)
        return AzureStorage(blob_client)
    raise UnknownStorageBackend()
def verify_existence(storage: Storage, bucket: str, file_name: str) -> str:
if not storage.exists(bucket, file_name):
raise FileNotFoundError(f"{file_name=} name not found on storage in {bucket=}.")

View File

@ -1,45 +1,113 @@
from dataclasses import dataclass
from typing import Union
from typing import Callable
import requests
from azure.storage.blob import BlobServiceClient
from minio import Minio
from pyinfra.config import Config
from pyinfra.exception import UnknownStorageBackend
from pyinfra.storage.storages.azure import AzureStorage
from pyinfra.storage.storages.interface import Storage
from pyinfra.storage.storages.s3 import S3Storage
from pyinfra.utils.cipher import decrypt
from pyinfra.utils.url_parsing import validate_and_parse_s3_endpoint
@dataclass(frozen=True)
class AzureStorageInfo:
connection_string: str
class StorageInfo:
bucket_name: str
@dataclass(frozen=True)
class AzureStorageInfo(StorageInfo):
    """Connection details for an Azure blob storage backend.

    Identity is defined by the connection string alone: infos that differ only
    in bucket name compare and hash equal, so one cached connection serves all
    buckets of the same storage account.
    """

    connection_string: str

    def __eq__(self, other):
        return isinstance(other, AzureStorageInfo) and self.connection_string == other.connection_string

    def __hash__(self):
        return hash(self.connection_string)
@dataclass(frozen=True)
class S3StorageInfo(StorageInfo):
    """Connection details for an S3/MinIO storage backend.

    Hash and equality cover only the connection parameters — not the bucket —
    so infos that differ only in bucket share one cached connection.
    """

    secure: bool
    endpoint: str
    access_key: str
    secret_key: str
    region: str
    bucket_name: str

    def _connection_identity(self):
        # The tuple of fields that determine which physical connection is used.
        return (self.secure, self.endpoint, self.access_key, self.secret_key, self.region)

    def __hash__(self):
        return hash(self._connection_identity())

    def __eq__(self, other):
        if not isinstance(other, S3StorageInfo):
            return False
        return self._connection_identity() == other._connection_identity()
def get_storage_info_from_endpoint(
public_key: str, endpoint: str, x_tenant_id: str
) -> Union[AzureStorageInfo, S3StorageInfo]:
# FIXME: parameterize port, host and public_key
public_key = "redaction"
def get_storage_from_storage_info(storage_info: StorageInfo) -> Storage:
    """Create the concrete storage client matching *storage_info*'s backend type.

    Raises:
        UnknownStorageBackend: for any unrecognized StorageInfo subtype.
    """
    if isinstance(storage_info, AzureStorageInfo):
        service_client = BlobServiceClient.from_connection_string(conn_str=storage_info.connection_string)
        return AzureStorage(service_client)
    if isinstance(storage_info, S3StorageInfo):
        minio_client = Minio(
            secure=storage_info.secure,
            endpoint=storage_info.endpoint,
            access_key=storage_info.access_key,
            secret_key=storage_info.secret_key,
            region=storage_info.region,
        )
        return S3Storage(minio_client)
    raise UnknownStorageBackend()
class DefaultStorageCache(dict):
    """LRU cache mapping a StorageInfo to its lazily created storage connection.

    Guarantees each unique StorageInfo triggers at most one connection setup:
    the first lookup builds the value via *get_value* and caches it; later
    lookups return the cached connection. When *max_size* is set, the least
    recently used entry is evicted before a new key is inserted.
    """

    def __init__(self, max_size=None, get_value: Callable = None, *args, **kwargs):
        """
        Args:
            max_size: maximum number of cached connections; None means unbounded.
            get_value: factory producing a value from a key; defaults to
                get_storage_from_storage_info (resolved lazily so the class
                definition does not depend on it).
        """
        super().__init__(*args, **kwargs)
        self.max_size = max_size
        self.get_value = get_value if get_value is not None else get_storage_from_storage_info
        # Keys ordered least- to most-recently used, kept duplicate-free so
        # eviction can never pop a key that was already removed (the original
        # appended on every access and could raise KeyError on eviction).
        self.access_history = []

    def __getitem__(self, key: StorageInfo):
        # Membership test, not value truthiness: a falsy cached value must not
        # be recreated on every access.
        if key not in self:
            self.__setitem__(key)
        else:
            self._mark_used(key)
        return super().__getitem__(key)

    def __setitem__(self, key: StorageInfo, value: Storage = None):
        # Only a genuinely new key can breach the size limit; overwriting an
        # existing key must not evict anything.
        if key not in self:
            self._delete_oldest_key_if_max_size_breached()
        # `is not None` (not `or`) so an explicitly provided falsy value is kept.
        super().__setitem__(key, value if value is not None else self.get_value(key))
        self._mark_used(key)

    def set_key(self, key: StorageInfo):
        """Create and cache the value for *key* without returning it."""
        self.__setitem__(key)

    def _mark_used(self, key: StorageInfo):
        # Move *key* to the most-recently-used end of the history.
        if key in self.access_history:
            self.access_history.remove(key)
        self.access_history.append(key)

    def _delete_oldest_key_if_max_size_breached(self):
        # Unbounded cache: nothing to do. (The original compared
        # len(self) >= None and raised TypeError for the default max_size.)
        if self.max_size is None:
            return
        while self.access_history and len(self) >= self.max_size:
            oldest_key = self.access_history.pop(0)
            del self[oldest_key]
def get_storage_info_from_endpoint(public_key: str, endpoint: str, x_tenant_id: str) -> StorageInfo:
resp = requests.get(f"{endpoint}/{x_tenant_id}").json()
maybe_azure = resp.get("azureStorageConnection")
maybe_s3 = resp.get("azureStorageConnection")
maybe_s3 = resp.get("s3StorageConnection")
assert not (maybe_azure and maybe_s3)
if maybe_azure:
@ -66,7 +134,7 @@ def get_storage_info_from_endpoint(
return storage_info
def get_storage_info_from_config(config: Config) -> Union[AzureStorageInfo, S3StorageInfo]:
def get_storage_info_from_config(config: Config) -> StorageInfo:
if config.storage_backend == "s3":
storage_info = S3StorageInfo(
secure=config.storage_secure_connection,

View File

@ -21,6 +21,7 @@ def expected_parsed_payload(x_tenant_id):
response_compression_type="gz",
target_file_name="test/test.json.gz",
response_file_name="test/test.json.gz",
processing_kwargs={},
)

View File

@ -0,0 +1,47 @@
from dataclasses import asdict
import pytest
from pyinfra.storage.storage_info import AzureStorageInfo, StorageInfo, DefaultStorageCache
@pytest.fixture
def storage_infos():
    """Four infos: two share a connection string, two have unique ones."""
    connection_bucket_pairs = [
        ("first", "Thorsten"),
        ("first", "Klaus"),
        ("second", "Thorsten"),
        ("third", "Klaus"),
    ]
    return [
        AzureStorageInfo(connection_string=connection, bucket_name=bucket)
        for connection, bucket in connection_bucket_pairs
    ]
@pytest.fixture
def storage_cache(max_size):
    """DefaultStorageCache whose value factory just dumps the StorageInfo to a dict."""

    def get_connection_from_storage_info_mock(storage_info: StorageInfo):
        # Deterministic, cheap stand-in for a real connection factory.
        return asdict(storage_info)

    return DefaultStorageCache(max_size=max_size, get_value=get_connection_from_storage_info_mock)
class TestStorageCache:
    """Behavioral tests for DefaultStorageCache hashing, reuse and eviction."""

    def test_same_storage_info_has_same_hash(self, storage_infos):
        first, second = storage_infos[0], storage_infos[1]
        assert hash(first) == hash(second)

    @pytest.mark.parametrize("max_size", [4])
    def test_same_connection_different_bucket_does_not_create_new_connection(self, storage_cache, storage_infos):
        expected = asdict(storage_infos[0])
        assert storage_cache[storage_infos[0]] == expected
        # Same connection string, different bucket: must hit the cached entry.
        assert storage_cache[storage_infos[1]] == expected

    @pytest.mark.parametrize("max_size", [2])
    def test_max_size_breached_removes_oldest_key(self, storage_cache, storage_infos):
        for info in (storage_infos[0], storage_infos[2], storage_infos[3]):
            storage_cache.set_key(info)
        assert len(storage_cache) == 2
        assert storage_infos[0] not in storage_cache
        assert storage_infos[2] in storage_cache
        assert storage_infos[3] in storage_cache