update PayloadProcessor
- Introduce a storage cache so that each unique storage connection is established only once.
- Add the ability to pass optional processing kwargs from the queue message (e.g. the `operation` key) to the processing function.
This commit is contained in:
parent
d48e8108fd
commit
97309fe580
@ -3,6 +3,8 @@ from itertools import chain
|
||||
from operator import itemgetter
|
||||
from typing import Union, Sized
|
||||
|
||||
from funcy import project
|
||||
|
||||
from pyinfra.config import Config
|
||||
from pyinfra.utils.file_extension_parsing import make_file_extension_parser
|
||||
|
||||
@ -24,10 +26,13 @@ class QueueMessagePayload:
|
||||
target_file_name: str
|
||||
response_file_name: str
|
||||
|
||||
processing_kwargs: dict
|
||||
|
||||
|
||||
class QueueMessagePayloadParser:
|
||||
def __init__(self, file_extension_parser):
|
||||
def __init__(self, file_extension_parser, allowed_processing_args=("operation",)):
|
||||
self.parse_file_extensions = file_extension_parser
|
||||
self.allowed_args = allowed_processing_args
|
||||
|
||||
def __call__(self, payload: dict) -> QueueMessagePayload:
|
||||
"""Translate the queue message payload to the internal QueueMessagePayload object."""
|
||||
@ -46,6 +51,8 @@ class QueueMessagePayloadParser:
|
||||
target_file_name = f"{dossier_id}/{file_id}.{target_file_extension}"
|
||||
response_file_name = f"{dossier_id}/{file_id}.{response_file_extension}"
|
||||
|
||||
processing_kwargs = project(payload, self.allowed_args)
|
||||
|
||||
return QueueMessagePayload(
|
||||
dossier_id=dossier_id,
|
||||
file_id=file_id,
|
||||
@ -58,6 +65,7 @@ class QueueMessagePayloadParser:
|
||||
response_compression_type=response_compression_type,
|
||||
target_file_name=target_file_name,
|
||||
response_file_name=response_file_name,
|
||||
processing_kwargs=processing_kwargs,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@ -1,9 +1,7 @@
|
||||
import logging
|
||||
from dataclasses import asdict
|
||||
from functools import partial
|
||||
from typing import Callable, Union, List
|
||||
|
||||
from funcy import compose
|
||||
from typing import Callable, List
|
||||
|
||||
from pyinfra.config import get_config, Config
|
||||
from pyinfra.payload_processing.monitor import get_monitor_from_config
|
||||
@ -13,12 +11,12 @@ from pyinfra.payload_processing.payload import (
|
||||
QueueMessagePayloadFormatter,
|
||||
get_queue_message_payload_formatter,
|
||||
)
|
||||
from pyinfra.storage.storage import make_downloader, make_uploader, get_storage_from_storage_info
|
||||
from pyinfra.storage.storage import make_downloader, make_uploader
|
||||
from pyinfra.storage.storage_info import (
|
||||
AzureStorageInfo,
|
||||
S3StorageInfo,
|
||||
get_storage_info_from_config,
|
||||
get_storage_info_from_endpoint,
|
||||
StorageInfo,
|
||||
DefaultStorageCache,
|
||||
)
|
||||
|
||||
logger = logging.getLogger()
|
||||
@ -28,7 +26,8 @@ logger.setLevel(get_config().logging_level_root)
|
||||
class PayloadProcessor:
|
||||
def __init__(
|
||||
self,
|
||||
default_storage_info: Union[AzureStorageInfo, S3StorageInfo],
|
||||
default_storage_info: StorageInfo,
|
||||
storage_cache: DefaultStorageCache,
|
||||
get_storage_info_from_tenant_id,
|
||||
payload_parser: QueueMessagePayloadParser,
|
||||
payload_formatter: QueueMessagePayloadFormatter,
|
||||
@ -53,8 +52,7 @@ class PayloadProcessor:
|
||||
|
||||
self.get_storage_info_from_tenant_id = get_storage_info_from_tenant_id
|
||||
self.default_storage_info = default_storage_info
|
||||
# TODO: use lru-dict
|
||||
self.storages = {}
|
||||
self.storages = storage_cache
|
||||
|
||||
def __call__(self, queue_message_payload: dict) -> dict:
|
||||
"""Processes a queue message payload.
|
||||
@ -79,8 +77,8 @@ class PayloadProcessor:
|
||||
logger.info(f"Processing {asdict(payload)} ...")
|
||||
|
||||
storage_info = self._get_storage_info(payload.x_tenant_id)
|
||||
storage = self.storages[storage_info]
|
||||
bucket = storage_info.bucket_name
|
||||
storage = self._get_storage(storage_info)
|
||||
|
||||
download_file_to_process = make_downloader(
|
||||
storage, bucket, payload.target_file_type, payload.target_compression_type
|
||||
@ -90,11 +88,11 @@ class PayloadProcessor:
|
||||
)
|
||||
format_result_for_storage = partial(self.format_result_for_storage, payload)
|
||||
|
||||
processing_pipeline = compose(format_result_for_storage, self.process_data, download_file_to_process)
|
||||
data = download_file_to_process(payload.target_file_name)
|
||||
result: List[dict] = self.process_data(data, **payload.processing_kwargs)
|
||||
formatted_result = format_result_for_storage(result)
|
||||
|
||||
result: List[dict] = processing_pipeline(payload.target_file_name)
|
||||
|
||||
upload_processing_result(payload.response_file_name, result)
|
||||
upload_processing_result(payload.response_file_name, formatted_result)
|
||||
|
||||
return self.format_to_queue_message_response_body(payload)
|
||||
|
||||
@ -103,20 +101,13 @@ class PayloadProcessor:
|
||||
return self.get_storage_info_from_tenant_id(x_tenant_id)
|
||||
return self.default_storage_info
|
||||
|
||||
def _get_storage(self, storage_info):
    """Return the storage for `storage_info`, creating it on first use.

    A storage that is not yet in `self.storages` is built with
    `get_storage_from_storage_info` and remembered under `storage_info`, so
    later calls with an equal key reuse it.
    """
    if storage_info not in self.storages:
        self.storages[storage_info] = get_storage_from_storage_info(storage_info)

    return self.storages[storage_info]
|
||||
|
||||
|
||||
def make_payload_processor(data_processor: Callable, config: Union[None, Config] = None) -> PayloadProcessor:
|
||||
def make_payload_processor(data_processor: Callable, config: Config = None) -> PayloadProcessor:
|
||||
"""Produces payload processor for queue manager."""
|
||||
config = config or get_config()
|
||||
|
||||
default_storage_info: Union[AzureStorageInfo, S3StorageInfo] = get_storage_info_from_config(config)
|
||||
storage_cache: DefaultStorageCache = DefaultStorageCache(max_size=10)
|
||||
default_storage_info: StorageInfo = get_storage_info_from_config(config)
|
||||
get_storage_info_from_tenant_id = partial(
|
||||
get_storage_info_from_endpoint,
|
||||
config.persistence_service_public_key,
|
||||
@ -130,6 +121,7 @@ def make_payload_processor(data_processor: Callable, config: Union[None, Config]
|
||||
|
||||
return PayloadProcessor(
|
||||
default_storage_info,
|
||||
storage_cache,
|
||||
get_storage_info_from_tenant_id,
|
||||
payload_parser,
|
||||
payload_formatter,
|
||||
|
||||
@ -1,16 +1,11 @@
|
||||
from functools import lru_cache, partial
|
||||
from typing import Callable, Union
|
||||
from typing import Callable
|
||||
|
||||
from azure.storage.blob import BlobServiceClient
|
||||
from funcy import compose
|
||||
from minio import Minio
|
||||
|
||||
from pyinfra.config import Config
|
||||
from pyinfra.exception import UnknownStorageBackend
|
||||
from pyinfra.storage.storage_info import AzureStorageInfo, S3StorageInfo, get_storage_info_from_config
|
||||
from pyinfra.storage.storages.azure import AzureStorage
|
||||
from pyinfra.storage.storage_info import get_storage_info_from_config, get_storage_from_storage_info
|
||||
from pyinfra.storage.storages.interface import Storage
|
||||
from pyinfra.storage.storages.s3 import S3Storage
|
||||
from pyinfra.utils.compressing import get_decompressor, get_compressor
|
||||
from pyinfra.utils.encoding import get_decoder, get_encoder
|
||||
|
||||
@ -23,23 +18,6 @@ def get_storage_from_config(config: Config) -> Storage:
|
||||
return storage
|
||||
|
||||
|
||||
def get_storage_from_storage_info(storage_info: Union[AzureStorageInfo, S3StorageInfo]) -> Storage:
    """Build the storage client wrapper that matches the given storage info.

    Args:
        storage_info: Connection details of an Azure or an S3 storage.

    Returns:
        An AzureStorage for an AzureStorageInfo, an S3Storage for an S3StorageInfo.

    Raises:
        UnknownStorageBackend: If `storage_info` is neither of the two types.
    """
    # Reject unsupported backends up front; everything below is a known type.
    if not isinstance(storage_info, (AzureStorageInfo, S3StorageInfo)):
        raise UnknownStorageBackend()

    if isinstance(storage_info, AzureStorageInfo):
        blob_service_client = BlobServiceClient.from_connection_string(conn_str=storage_info.connection_string)
        return AzureStorage(blob_service_client)

    minio_client = Minio(
        secure=storage_info.secure,
        endpoint=storage_info.endpoint,
        access_key=storage_info.access_key,
        secret_key=storage_info.secret_key,
        region=storage_info.region,
    )
    return S3Storage(minio_client)
|
||||
|
||||
|
||||
def verify_existence(storage: Storage, bucket: str, file_name: str) -> str:
|
||||
if not storage.exists(bucket, file_name):
|
||||
raise FileNotFoundError(f"{file_name=} name not found on storage in {bucket=}.")
|
||||
|
||||
@ -1,45 +1,113 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import Union
|
||||
from typing import Callable
|
||||
|
||||
import requests
|
||||
from azure.storage.blob import BlobServiceClient
|
||||
from minio import Minio
|
||||
|
||||
from pyinfra.config import Config
|
||||
from pyinfra.exception import UnknownStorageBackend
|
||||
from pyinfra.storage.storages.azure import AzureStorage
|
||||
from pyinfra.storage.storages.interface import Storage
|
||||
from pyinfra.storage.storages.s3 import S3Storage
|
||||
from pyinfra.utils.cipher import decrypt
|
||||
from pyinfra.utils.url_parsing import validate_and_parse_s3_endpoint
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class AzureStorageInfo:
|
||||
connection_string: str
|
||||
class StorageInfo:
|
||||
bucket_name: str
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class AzureStorageInfo(StorageInfo):
    """Connection details for an Azure storage.

    Equality and hashing look at `connection_string` only, so two infos that
    differ just in the inherited `bucket_name` act as the same dictionary key.
    """

    connection_string: str

    def __eq__(self, other):
        # Anything that is not an AzureStorageInfo never compares equal.
        return isinstance(other, AzureStorageInfo) and self.connection_string == other.connection_string

    def __hash__(self):
        # Must stay in line with __eq__: only the connection string counts.
        return hash(self.connection_string)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class S3StorageInfo:
|
||||
class S3StorageInfo(StorageInfo):
|
||||
secure: bool
|
||||
endpoint: str
|
||||
access_key: str
|
||||
secret_key: str
|
||||
region: str
|
||||
bucket_name: str
|
||||
|
||||
def __hash__(self):
|
||||
return hash((self.secure, self.endpoint, self.access_key, self.secret_key, self.region))
|
||||
|
||||
def __eq__(self, other):
|
||||
if not isinstance(other, S3StorageInfo):
|
||||
return False
|
||||
return (
|
||||
self.secure == other.secure
|
||||
and self.endpoint == other.endpoint
|
||||
and self.access_key == other.access_key
|
||||
and self.secret_key == other.secret_key
|
||||
and self.region == other.region
|
||||
)
|
||||
|
||||
def get_storage_info_from_endpoint(
|
||||
public_key: str, endpoint: str, x_tenant_id: str
|
||||
) -> Union[AzureStorageInfo, S3StorageInfo]:
|
||||
# FIXME: parameterize port, host and public_key
|
||||
public_key = "redaction"
|
||||
|
||||
def get_storage_from_storage_info(storage_info: StorageInfo) -> Storage:
    """Create the storage that belongs to the given storage info.

    Args:
        storage_info: An AzureStorageInfo or an S3StorageInfo.

    Returns:
        An AzureStorage wrapping a BlobServiceClient, or an S3Storage wrapping
        a Minio client, depending on the type of `storage_info`.

    Raises:
        UnknownStorageBackend: If `storage_info` is of any other type.
    """
    if isinstance(storage_info, AzureStorageInfo):
        azure_client = BlobServiceClient.from_connection_string(conn_str=storage_info.connection_string)
        return AzureStorage(azure_client)

    if isinstance(storage_info, S3StorageInfo):
        s3_client = Minio(
            secure=storage_info.secure,
            endpoint=storage_info.endpoint,
            access_key=storage_info.access_key,
            secret_key=storage_info.secret_key,
            region=storage_info.region,
        )
        return S3Storage(s3_client)

    raise UnknownStorageBackend()
|
||||
|
||||
|
||||
class DefaultStorageCache(dict):
    """Size-bounded cache that maps a storage info to its storage connection.

    Looking up a key that is not cached creates the value with
    `get_value(key)` and stores it, so a key is connected only once for as
    long as it stays in the cache. When `max_size` entries are cached,
    inserting a new key first evicts the least recently used entry.

    Attributes:
        max_size: Maximum number of cached entries; None disables eviction.
            Values below 1 still keep the most recently inserted entry.
        get_value: Factory called with the key to create a missing value.
        access_history: The cached keys ordered from least to most recently
            used; every key appears at most once.
    """

    def __init__(self, max_size=None, get_value: Callable = None, *args, **kwargs):
        """Set up the cache.

        Args:
            max_size: Maximum number of cached entries, or None for no limit.
            get_value: Factory creating the value for a key. When omitted,
                `get_storage_from_storage_info` is used.
            *args: Initial entries, forwarded to `dict`.
            **kwargs: Initial entries, forwarded to `dict`.
        """
        super().__init__(*args, **kwargs)
        self.max_size = max_size
        self.get_value = get_storage_from_storage_info if get_value is None else get_value
        # Entries handed to the constructor take part in eviction as well.
        self.access_history = list(self)

    def __getitem__(self, key: "StorageInfo"):
        """Return the cached value for `key`, creating and caching it on a miss."""
        # Membership (not truthiness) decides about a miss, so falsy values
        # are cached like any other value instead of being re-created.
        if key in self:
            self._mark_as_most_recently_used(key)
        else:
            self.__setitem__(key)
        return super().__getitem__(key)

    def __setitem__(self, key: "StorageInfo", value: "Storage" = None):
        """Store `value` under `key`; a value of None is created with `get_value(key)`."""
        if value is None:
            # Create the value before evicting anything: if the factory
            # raises, the cache is left exactly as it was.
            value = self.get_value(key)
        if key not in self:
            # Only a genuinely new key needs room; overwriting does not.
            self._delete_oldest_key_if_max_size_breached()
        super().__setitem__(key, value)
        self._mark_as_most_recently_used(key)

    def __delitem__(self, key: "StorageInfo"):
        """Remove `key` from the cache and from the access history."""
        super().__delitem__(key)
        self._forget(key)

    def set_key(self, key: "StorageInfo"):
        """Create the value for `key` with `get_value` and cache it."""
        self.__setitem__(key)

    def _delete_oldest_key_if_max_size_breached(self):
        """Evict least recently used entries until a new key fits."""
        if self.max_size is None:
            return
        while self.access_history and len(self) >= self.max_size:
            oldest_key = self.access_history.pop(0)
            # Tolerate keys that were removed behind our back (pop/clear).
            super().pop(oldest_key, None)

    def _mark_as_most_recently_used(self, key):
        """Move `key` to the most-recent end of the access history."""
        self._forget(key)
        self.access_history.append(key)

    def _forget(self, key):
        """Drop `key` from the access history if it is recorded there."""
        try:
            self.access_history.remove(key)
        except ValueError:
            pass
|
||||
|
||||
|
||||
def get_storage_info_from_endpoint(public_key: str, endpoint: str, x_tenant_id: str) -> StorageInfo:
|
||||
resp = requests.get(f"{endpoint}/{x_tenant_id}").json()
|
||||
|
||||
maybe_azure = resp.get("azureStorageConnection")
|
||||
maybe_s3 = resp.get("azureStorageConnection")
|
||||
maybe_s3 = resp.get("s3StorageConnection")
|
||||
assert not (maybe_azure and maybe_s3)
|
||||
|
||||
if maybe_azure:
|
||||
@ -66,7 +134,7 @@ def get_storage_info_from_endpoint(
|
||||
return storage_info
|
||||
|
||||
|
||||
def get_storage_info_from_config(config: Config) -> Union[AzureStorageInfo, S3StorageInfo]:
|
||||
def get_storage_info_from_config(config: Config) -> StorageInfo:
|
||||
if config.storage_backend == "s3":
|
||||
storage_info = S3StorageInfo(
|
||||
secure=config.storage_secure_connection,
|
||||
|
||||
@ -21,6 +21,7 @@ def expected_parsed_payload(x_tenant_id):
|
||||
response_compression_type="gz",
|
||||
target_file_name="test/test.json.gz",
|
||||
response_file_name="test/test.json.gz",
|
||||
processing_kwargs={},
|
||||
)
|
||||
|
||||
|
||||
|
||||
47
tests/storage_cache_test.py
Normal file
47
tests/storage_cache_test.py
Normal file
@ -0,0 +1,47 @@
|
||||
from dataclasses import asdict
|
||||
|
||||
import pytest
|
||||
|
||||
from pyinfra.storage.storage_info import AzureStorageInfo, StorageInfo, DefaultStorageCache
|
||||
|
||||
|
||||
@pytest.fixture
def storage_infos():
    """Four Azure storage infos; the first two share a connection string but differ in bucket."""
    connection_and_bucket = (
        ("first", "Thorsten"),
        ("first", "Klaus"),
        ("second", "Thorsten"),
        ("third", "Klaus"),
    )
    return [
        AzureStorageInfo(connection_string=connection, bucket_name=bucket)
        for connection, bucket in connection_and_bucket
    ]
|
||||
|
||||
|
||||
@pytest.fixture
def storage_cache(max_size):
    """A DefaultStorageCache of `max_size` whose "connections" are plain dicts.

    `asdict` stands in for the real connection factory: the cached value for a
    storage info is simply its dictionary representation.
    """
    return DefaultStorageCache(max_size=max_size, get_value=asdict)
|
||||
|
||||
|
||||
class TestStorageCache:
    """Checks key identity and eviction of DefaultStorageCache."""

    def test_same_storage_info_has_same_hash(self, storage_infos):
        first, same_connection_other_bucket = storage_infos[:2]

        assert hash(first) == hash(same_connection_other_bucket)

    @pytest.mark.parametrize("max_size", [4])
    def test_same_connection_different_bucket_does_not_create_new_connection(self, storage_cache, storage_infos):
        first, same_connection_other_bucket = storage_infos[:2]
        expected_connection = asdict(first)

        # The first lookup creates the connection ...
        assert storage_cache[first] == expected_connection
        # ... and the second one must hand back that very connection.
        assert storage_cache[same_connection_other_bucket] == expected_connection

    @pytest.mark.parametrize("max_size", [2])
    def test_max_size_breached_removes_oldest_key(self, storage_cache, storage_infos):
        oldest, _, middle, newest = storage_infos

        for storage_info in (oldest, middle, newest):
            storage_cache.set_key(storage_info)

        assert len(storage_cache) == 2
        assert oldest not in storage_cache
        assert middle in storage_cache
        assert newest in storage_cache
|
||||
Loading…
x
Reference in New Issue
Block a user