Pull request #68: RED-6273 multi tenant storage

Merge in RR/pyinfra from RED-6273-multi-tenant-storage to master

Squashed commit of the following:

commit 0fead1f8b59c9187330879b4e48d48355885c27c
Author: Julius Unverfehrt <julius.unverfehrt@iqser.com>
Date:   Tue Mar 28 15:02:22 2023 +0200

    fix typos

commit 892a803726946876f8b8cd7905a0e73c419b2fb1
Author: Matthias Bisping <matthias.bisping@axbit.com>
Date:   Tue Mar 28 14:41:49 2023 +0200

    Refactoring

    Replace custom storage caching logic with LRU decorator

commit eafcd90260731e3360ce960571f07dee8f521327
Author: Julius Unverfehrt <julius.unverfehrt@iqser.com>
Date:   Fri Mar 24 12:50:13 2023 +0100

    fix bug in storage connection from endpoint

commit d0c9fb5b7d1c55ae2f90e8faa1efec9f7587c26a
Author: Julius Unverfehrt <julius.unverfehrt@iqser.com>
Date:   Fri Mar 24 11:49:34 2023 +0100

    add logs to PayloadProcessor

    - set log messages to determine if x-tenant
    storage connection is working

commit 97309fe58037b90469cf7a3de342d4749a0edfde
Author: Julius Unverfehrt <julius.unverfehrt@iqser.com>
Date:   Fri Mar 24 10:41:59 2023 +0100

    update PayloadProcessor

    - introduce storage cache to make every unique
    storage connection only once
    - add functionality to pass optional processing
    kwargs in queue message like the operation key to
    the processing function

commit d48e8108fdc0d463c89aaa0d672061ab7dca83a0
Author: Julius Unverfehrt <julius.unverfehrt@iqser.com>
Date:   Wed Mar 22 13:34:43 2023 +0100

    add multi-tenant storage connection 1st iteration

    - forward x-tenant-id from queue message header to
    payload processor
    - add functions to receive storage infos from an
    endpoint or the config. This enables hashing and
    caching of connections created from these infos
    - add function to initialize storage connections
    from storage infos
    - streamline and refactor tests to make them more
    readable and robust and to make it easier to add
     new tests
    - update payload processor with first iteration
    of multi tenancy storage connection support
    with connection caching and backwards compatibility

commit 52c047c47b98e62d0b834a9b9b6c0e2bb0db41e5
Author: Julius Unverfehrt <julius.unverfehrt@iqser.com>
Date:   Tue Mar 21 15:35:57 2023 +0100

    add AES/GCM cipher functions

    - decrypt x-tenant storage connection strings
This commit is contained in:
Julius Unverfehrt 2023-03-28 15:04:14 +02:00
parent 0f24a7f26d
commit 793a427c50
26 changed files with 547 additions and 395 deletions

45
poetry.lock generated
View File

@ -1360,6 +1360,49 @@ files = [
{file = "pycparser-2.21.tar.gz", hash = "sha256:e644fdec12f7872f86c58ff790da456218b10f863970249516d60a5eaca77206"},
]
[[package]]
name = "pycryptodome"
version = "3.17"
description = "Cryptographic library for Python"
category = "main"
optional = false
python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*"
files = [
{file = "pycryptodome-3.17-cp27-cp27m-macosx_10_9_x86_64.whl", hash = "sha256:2c5631204ebcc7ae33d11c43037b2dafe25e2ab9c1de6448eb6502ac69c19a56"},
{file = "pycryptodome-3.17-cp27-cp27m-manylinux2010_i686.whl", hash = "sha256:04779cc588ad8f13c80a060b0b1c9d1c203d051d8a43879117fe6b8aaf1cd3fa"},
{file = "pycryptodome-3.17-cp27-cp27m-manylinux2010_x86_64.whl", hash = "sha256:f812d58c5af06d939b2baccdda614a3ffd80531a26e5faca2c9f8b1770b2b7af"},
{file = "pycryptodome-3.17-cp27-cp27m-manylinux2014_aarch64.whl", hash = "sha256:9453b4e21e752df8737fdffac619e93c9f0ec55ead9a45df782055eb95ef37d9"},
{file = "pycryptodome-3.17-cp27-cp27m-musllinux_1_1_aarch64.whl", hash = "sha256:121d61663267f73692e8bde5ec0d23c9146465a0d75cad75c34f75c752527b01"},
{file = "pycryptodome-3.17-cp27-cp27m-win32.whl", hash = "sha256:ba2d4fcb844c6ba5df4bbfee9352ad5352c5ae939ac450e06cdceff653280450"},
{file = "pycryptodome-3.17-cp27-cp27m-win_amd64.whl", hash = "sha256:87e2ca3aa557781447428c4b6c8c937f10ff215202ab40ece5c13a82555c10d6"},
{file = "pycryptodome-3.17-cp27-cp27mu-manylinux2010_i686.whl", hash = "sha256:f44c0d28716d950135ff21505f2c764498eda9d8806b7c78764165848aa419bc"},
{file = "pycryptodome-3.17-cp27-cp27mu-manylinux2010_x86_64.whl", hash = "sha256:5a790bc045003d89d42e3b9cb3cc938c8561a57a88aaa5691512e8540d1ae79c"},
{file = "pycryptodome-3.17-cp27-cp27mu-manylinux2014_aarch64.whl", hash = "sha256:d086d46774e27b280e4cece8ab3d87299cf0d39063f00f1e9290d096adc5662a"},
{file = "pycryptodome-3.17-cp27-cp27mu-musllinux_1_1_aarch64.whl", hash = "sha256:5587803d5b66dfd99e7caa31ed91fba0fdee3661c5d93684028ad6653fce725f"},
{file = "pycryptodome-3.17-cp35-abi3-macosx_10_9_universal2.whl", hash = "sha256:e7debd9c439e7b84f53be3cf4ba8b75b3d0b6e6015212355d6daf44ac672e210"},
{file = "pycryptodome-3.17-cp35-abi3-macosx_10_9_x86_64.whl", hash = "sha256:ca1ceb6303be1282148f04ac21cebeebdb4152590842159877778f9cf1634f09"},
{file = "pycryptodome-3.17-cp35-abi3-manylinux2014_aarch64.whl", hash = "sha256:dc22cc00f804485a3c2a7e2010d9f14a705555f67020eb083e833cabd5bd82e4"},
{file = "pycryptodome-3.17-cp35-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:80ea8333b6a5f2d9e856ff2293dba2e3e661197f90bf0f4d5a82a0a6bc83a626"},
{file = "pycryptodome-3.17-cp35-abi3-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:c133f6721fba313722a018392a91e3c69d3706ae723484841752559e71d69dc6"},
{file = "pycryptodome-3.17-cp35-abi3-musllinux_1_1_aarch64.whl", hash = "sha256:333306eaea01fde50a73c4619e25631e56c4c61bd0fb0a2346479e67e3d3a820"},
{file = "pycryptodome-3.17-cp35-abi3-musllinux_1_1_i686.whl", hash = "sha256:1a30f51b990994491cec2d7d237924e5b6bd0d445da9337d77de384ad7f254f9"},
{file = "pycryptodome-3.17-cp35-abi3-musllinux_1_1_x86_64.whl", hash = "sha256:909e36a43fe4a8a3163e9c7fc103867825d14a2ecb852a63d3905250b308a4e5"},
{file = "pycryptodome-3.17-cp35-abi3-win32.whl", hash = "sha256:a3228728a3808bc9f18c1797ec1179a0efb5068c817b2ffcf6bcd012494dffb2"},
{file = "pycryptodome-3.17-cp35-abi3-win_amd64.whl", hash = "sha256:9ec565e89a6b400eca814f28d78a9ef3f15aea1df74d95b28b7720739b28f37f"},
{file = "pycryptodome-3.17-pp27-pypy_73-macosx_10_9_x86_64.whl", hash = "sha256:e1819b67bcf6ca48341e9b03c2e45b1c891fa8eb1a8458482d14c2805c9616f2"},
{file = "pycryptodome-3.17-pp27-pypy_73-manylinux2010_x86_64.whl", hash = "sha256:f8e550caf52472ae9126953415e4fc554ab53049a5691c45b8816895c632e4d7"},
{file = "pycryptodome-3.17-pp27-pypy_73-win32.whl", hash = "sha256:afbcdb0eda20a0e1d44e3a1ad6d4ec3c959210f4b48cabc0e387a282f4c7deb8"},
{file = "pycryptodome-3.17-pp38-pypy38_pp73-macosx_10_9_x86_64.whl", hash = "sha256:a74f45aee8c5cc4d533e585e0e596e9f78521e1543a302870a27b0ae2106381e"},
{file = "pycryptodome-3.17-pp38-pypy38_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:38bbd6717eac084408b4094174c0805bdbaba1f57fc250fd0309ae5ec9ed7e09"},
{file = "pycryptodome-3.17-pp38-pypy38_pp73-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:f68d6c8ea2974a571cacb7014dbaada21063a0375318d88ac1f9300bc81e93c3"},
{file = "pycryptodome-3.17-pp38-pypy38_pp73-win_amd64.whl", hash = "sha256:8198f2b04c39d817b206ebe0db25a6653bb5f463c2319d6f6d9a80d012ac1e37"},
{file = "pycryptodome-3.17-pp39-pypy39_pp73-macosx_10_9_x86_64.whl", hash = "sha256:3a232474cd89d3f51e4295abe248a8b95d0332d153bf46444e415409070aae1e"},
{file = "pycryptodome-3.17-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4992ec965606054e8326e83db1c8654f0549cdb26fce1898dc1a20bc7684ec1c"},
{file = "pycryptodome-3.17-pp39-pypy39_pp73-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:53068e33c74f3b93a8158dacaa5d0f82d254a81b1002e0cd342be89fcb3433eb"},
{file = "pycryptodome-3.17-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:74794a2e2896cd0cf56fdc9db61ef755fa812b4a4900fa46c49045663a92b8d0"},
{file = "pycryptodome-3.17.tar.gz", hash = "sha256:bce2e2d8e82fcf972005652371a3e8731956a0c1fbb719cc897943b3695ad91b"},
]
[[package]]
name = "pygments"
version = "2.14.0"
@ -2041,4 +2084,4 @@ testing = ["big-O", "flake8 (<5)", "jaraco.functools", "jaraco.itertools", "more
[metadata]
lock-version = "2.0"
python-versions = "~3.8"
content-hash = "d5017298c5fddc7d919a319fec7e050e8f7603c0c11c08d5c6c3115e751083e2"
content-hash = "43b21cbd83ef95d9d28a72856f9b2d78ad4f0c8c740af58bf8b40d1c61dba50d"

View File

@ -92,6 +92,10 @@ class Config:
self.allowed_file_types = ["json", "pdf"]
self.allowed_compression_types = ["gz"]
# config for x-tenant-endpoint to receive storage connection information per tenant
self.persistence_service_public_key = "redaction"
self.persistence_service_tenant_endpoint = "http://persistence-service-v1:8080/internal-api/tenants"
# Value to see if we should write a consumer token to a file
self.write_consumer_token = read_from_environment("WRITE_CONSUMER_TOKEN", "False")

View File

@ -1,2 +1,5 @@
class ProcessingFailure(RuntimeError):
    """Signals that processing a queue message payload failed."""
class UnknownStorageBackend(Exception):
    """Signals that no supported storage backend could be determined."""

View File

@ -12,7 +12,7 @@ logger = logging.getLogger()
class PrometheusMonitor:
def __init__(self, prefix: str, port=8080, host="127.0.0.1"):
def __init__(self, prefix: str, host: str, port: int):
"""Register the monitoring metrics and start a webserver where they can be scraped at the endpoint
http://{host}:{port}/prometheus
@ -23,12 +23,9 @@ class PrometheusMonitor:
self.registry = CollectorRegistry()
self.entity_processing_time_sum = Summary(
f"{prefix}_processing_time",
"Summed up average processing time per entity observed",
f"{prefix}_processing_time", "Summed up average processing time per entity observed", registry=self.registry
)
self.registry.register(self.entity_processing_time_sum)
start_http_server(port, host, self.registry)
def __call__(self, process_fn: Callable) -> Callable:
@ -58,8 +55,8 @@ class PrometheusMonitor:
return inner
def get_monitor(config: Config) -> Callable:
def get_monitor_from_config(config: Config) -> Callable:
if config.monitoring_enabled:
return PrometheusMonitor(*attrgetter("prometheus_metric_prefix", "prometheus_port", "prometheus_host")(config))
return PrometheusMonitor(*attrgetter("prometheus_metric_prefix", "prometheus_host", "prometheus_port")(config))
else:
return identity

View File

@ -3,6 +3,8 @@ from itertools import chain
from operator import itemgetter
from typing import Union, Sized
from funcy import project
from pyinfra.config import Config
from pyinfra.utils.file_extension_parsing import make_file_extension_parser
@ -11,6 +13,8 @@ from pyinfra.utils.file_extension_parsing import make_file_extension_parser
class QueueMessagePayload:
dossier_id: str
file_id: str
x_tenant_id: Union[str, None]
target_file_extension: str
response_file_extension: str
@ -22,10 +26,13 @@ class QueueMessagePayload:
target_file_name: str
response_file_name: str
processing_kwargs: dict
class QueueMessagePayloadParser:
def __init__(self, file_extension_parser):
def __init__(self, file_extension_parser, allowed_processing_args=("operation",)):
self.parse_file_extensions = file_extension_parser
self.allowed_args = allowed_processing_args
def __call__(self, payload: dict) -> QueueMessagePayload:
"""Translate the queue message payload to the internal QueueMessagePayload object."""
@ -35,6 +42,7 @@ class QueueMessagePayloadParser:
dossier_id, file_id, target_file_extension, response_file_extension = itemgetter(
"dossierId", "fileId", "targetFileExtension", "responseFileExtension"
)(payload)
x_tenant_id = payload.get("X-TENANT-ID")
target_file_type, target_compression_type, response_file_type, response_compression_type = chain.from_iterable(
map(self.parse_file_extensions, [target_file_extension, response_file_extension])
@ -43,9 +51,12 @@ class QueueMessagePayloadParser:
target_file_name = f"{dossier_id}/{file_id}.{target_file_extension}"
response_file_name = f"{dossier_id}/{file_id}.{response_file_extension}"
processing_kwargs = project(payload, self.allowed_args)
return QueueMessagePayload(
dossier_id=dossier_id,
file_id=file_id,
x_tenant_id=x_tenant_id,
target_file_extension=target_file_extension,
response_file_extension=response_file_extension,
target_file_type=target_file_type,
@ -54,6 +65,7 @@ class QueueMessagePayloadParser:
response_compression_type=response_compression_type,
target_file_name=target_file_name,
response_file_name=response_file_name,
processing_kwargs=processing_kwargs,
)

View File

@ -1,21 +1,23 @@
import logging
from dataclasses import asdict
from functools import partial
from typing import Callable, Union, List
from funcy import compose
from typing import Callable, List
from pyinfra.config import get_config, Config
from pyinfra.payload_processing.monitor import get_monitor
from pyinfra.payload_processing.monitor import get_monitor_from_config
from pyinfra.payload_processing.payload import (
QueueMessagePayloadParser,
get_queue_message_payload_parser,
QueueMessagePayloadFormatter,
get_queue_message_payload_formatter,
)
from pyinfra.storage import get_storage
from pyinfra.storage.storage import make_downloader, make_uploader
from pyinfra.storage.storages.interface import Storage
from pyinfra.storage.storage_info import (
get_storage_info_from_config,
get_storage_info_from_endpoint,
StorageInfo,
get_storage_from_storage_info,
)
logger = logging.getLogger()
logger.setLevel(get_config().logging_level_root)
@ -24,8 +26,8 @@ logger.setLevel(get_config().logging_level_root)
class PayloadProcessor:
def __init__(
self,
storage: Storage,
bucket: str,
default_storage_info: StorageInfo,
get_storage_info_from_tenant_id,
payload_parser: QueueMessagePayloadParser,
payload_formatter: QueueMessagePayloadFormatter,
data_processor: Callable,
@ -33,21 +35,22 @@ class PayloadProcessor:
"""Wraps an analysis function specified by a service (e.g. NER service) in pre- and post-processing steps.
Args:
storage: The storage to use for downloading and uploading files
bucket: The bucket to use for downloading and uploading files
default_storage_info: The default storage info used to create the storage connection. This is only used if
x_tenant_id is not provided in the queue payload.
get_storage_info_from_tenant_id: Callable to acquire storage info from a given tenant id.
payload_parser: Parser that translates the queue message payload to the required QueueMessagePayload object
payload_formatter: Formatter for the storage upload result and the queue message response body
data_processor: The analysis function to be called with the downloaded file
NOTE: the result of the analysis function has to be Sized, e.g. a dict or a list to be able to upload it
and to be able to monitor the processing time.
NOTE: The result of the analysis function has to be an instance of `Sized`, e.g. a dict or a list to be
able to upload it and to be able to monitor the processing time.
"""
self.parse_payload = payload_parser
self.format_result_for_storage = payload_formatter.format_service_processing_result_for_storage
self.format_to_queue_message_response_body = payload_formatter.format_to_queue_message_response_body
self.process_data = data_processor
self.partial_download_fn = partial(make_downloader, storage, bucket)
self.partial_upload_fn = partial(make_uploader, storage, bucket)
self.get_storage_info_from_tenant_id = get_storage_info_from_tenant_id
self.default_storage_info = default_storage_info
def __call__(self, queue_message_payload: dict) -> dict:
"""Processes a queue message payload.
@ -71,29 +74,58 @@ class PayloadProcessor:
payload = self.parse_payload(queue_message_payload)
logger.info(f"Processing {asdict(payload)} ...")
download_file_to_process = self.partial_download_fn(payload.target_file_type, payload.target_compression_type)
upload_processing_result = self.partial_upload_fn(payload.response_file_type, payload.response_compression_type)
storage_info = self._get_storage_info(payload.x_tenant_id)
storage = get_storage_from_storage_info(storage_info)
bucket = storage_info.bucket_name
download_file_to_process = make_downloader(
storage, bucket, payload.target_file_type, payload.target_compression_type
)
upload_processing_result = make_uploader(
storage, bucket, payload.response_file_type, payload.response_compression_type
)
format_result_for_storage = partial(self.format_result_for_storage, payload)
processing_pipeline = compose(format_result_for_storage, self.process_data, download_file_to_process)
data = download_file_to_process(payload.target_file_name)
result: List[dict] = self.process_data(data, **payload.processing_kwargs)
formatted_result = format_result_for_storage(result)
result: List[dict] = processing_pipeline(payload.target_file_name)
upload_processing_result(payload.response_file_name, result)
upload_processing_result(payload.response_file_name, formatted_result)
return self.format_to_queue_message_response_body(payload)
def _get_storage_info(self, x_tenant_id=None):
    """Select the storage info to use for the current payload.

    Args:
        x_tenant_id: Tenant id taken from the queue message payload, or a falsy value if none was provided.

    Returns:
        The storage info obtained from `self.get_storage_info_from_tenant_id` when a tenant id is given,
        otherwise `self.default_storage_info`.
    """
    if x_tenant_id:
        storage_info = self.get_storage_info_from_tenant_id(x_tenant_id)
        logger.info(f"Received {storage_info.__class__.__name__} for {x_tenant_id} from endpoint.")
    else:
        storage_info = self.default_storage_info
        logger.info(f"Using local default {storage_info.__class__.__name__} for {x_tenant_id}.")

    # Credentials must never end up in the logs, not even on debug level: mask the secret-bearing fields
    # of the storage info dataclasses before dumping the remaining connection details.
    sensitive_fields = ("secret_key", "connection_string")
    loggable_info = {
        field: "***" if field in sensitive_fields else value for field, value in asdict(storage_info).items()
    }
    logger.debug(f"{loggable_info}")

    return storage_info
def make_payload_processor(data_processor: Callable, config: Union[None, Config] = None) -> PayloadProcessor:
def make_payload_processor(data_processor: Callable, config: Config = None) -> PayloadProcessor:
"""Produces payload processor for queue manager."""
config = config or get_config()
bucket: str = config.storage_bucket
storage: Storage = get_storage(config)
monitor = get_monitor(config)
default_storage_info: StorageInfo = get_storage_info_from_config(config)
get_storage_info_from_tenant_id = partial(
get_storage_info_from_endpoint,
config.persistence_service_public_key,
config.persistence_service_tenant_endpoint,
)
monitor = get_monitor_from_config(config)
payload_parser: QueueMessagePayloadParser = get_queue_message_payload_parser(config)
payload_formatter: QueueMessagePayloadFormatter = get_queue_message_payload_formatter()
data_processor = monitor(data_processor)
return PayloadProcessor(storage, bucket, payload_parser, payload_formatter, data_processor)
return PayloadProcessor(
default_storage_info,
get_storage_info_from_tenant_id,
payload_parser,
payload_formatter,
data_processor,
)

View File

@ -12,6 +12,7 @@ from pika.adapters.blocking_connection import BlockingChannel
from pyinfra.config import Config
from pyinfra.exception import ProcessingFailure
from pyinfra.payload_processing.processor import PayloadProcessor
from pyinfra.utils.dict import safe_project
CONFIG = Config()
@ -164,8 +165,8 @@ class QueueManager:
except Exception as err:
raise ProcessingFailure("QueueMessagePayload processing failed") from err
def acknowledge_message_and_publish_response(frame, properties, response_body):
response_properties = pika.BasicProperties(headers=properties.headers) if properties.headers else None
def acknowledge_message_and_publish_response(frame, headers, response_body):
response_properties = pika.BasicProperties(headers=headers) if headers else None
self._channel.basic_publish("", self._output_queue, json.dumps(response_body).encode(), response_properties)
self.logger.info(
"Result published, acknowledging incoming message with delivery_tag %s",
@ -190,12 +191,15 @@ class QueueManager:
try:
self.logger.debug("Processing (%s, %s, %s)", frame, properties, body)
processing_result = process_message_body_and_await_result(json.loads(body))
filtered_message_headers = safe_project(properties.headers, ["X-TENANT-ID"]) # TODO: parametrize key?
message_body = {**json.loads(body), **filtered_message_headers}
processing_result = process_message_body_and_await_result(message_body)
self.logger.info(
"Processed message with delivery_tag %s, publishing result to result-queue",
frame.delivery_tag,
)
acknowledge_message_and_publish_response(frame, properties, processing_result)
acknowledge_message_and_publish_response(frame, filtered_message_headers, processing_result)
except ProcessingFailure:
self.logger.info(

View File

@ -1,3 +1,3 @@
from pyinfra.storage.storage import get_storage
from pyinfra.storage.storage import get_storage_from_config
__all__ = ["get_storage"]
__all__ = ["get_storage_from_config"]

View File

@ -4,21 +4,16 @@ from typing import Callable
from funcy import compose
from pyinfra.config import Config
from pyinfra.storage.storages.azure import get_azure_storage
from pyinfra.storage.storages.s3 import get_s3_storage
from pyinfra.storage.storage_info import get_storage_info_from_config, get_storage_from_storage_info
from pyinfra.storage.storages.interface import Storage
from pyinfra.utils.compressing import get_decompressor, get_compressor
from pyinfra.utils.encoding import get_decoder, get_encoder
def get_storage(config: Config) -> Storage:
def get_storage_from_config(config: Config) -> Storage:
if config.storage_backend == "s3":
storage = get_s3_storage(config)
elif config.storage_backend == "azure":
storage = get_azure_storage(config)
else:
raise Exception(f"Unknown storage backend '{config.storage_backend}'.")
storage_info = get_storage_info_from_config(config)
storage = get_storage_from_storage_info(storage_info)
return storage

View File

@ -0,0 +1,125 @@
from dataclasses import dataclass
import requests
from azure.storage.blob import BlobServiceClient
from minio import Minio
from pyinfra.config import Config
from pyinfra.exception import UnknownStorageBackend
from pyinfra.storage.storages.azure import AzureStorage
from pyinfra.storage.storages.interface import Storage
from pyinfra.storage.storages.s3 import S3Storage
from pyinfra.utils.cipher import decrypt
from pyinfra.utils.url_parsing import validate_and_parse_s3_endpoint
@dataclass(frozen=True)
class StorageInfo:
    """Immutable base record describing how to reach a storage backend.

    Backend-specific subclasses add their own connection fields.
    """

    # Name of the bucket (S3) or container (Azure) the files live in.
    bucket_name: str
@dataclass(frozen=True)
class AzureStorageInfo(StorageInfo):
    """Connection details for an Azure blob storage.

    Equality and hashing look at the connection string only; `bucket_name` is ignored. Two infos sharing a
    connection string therefore compare equal even if they name different containers.
    """

    connection_string: str

    def __eq__(self, other):
        return isinstance(other, AzureStorageInfo) and self.connection_string == other.connection_string

    def __hash__(self):
        return hash(self.connection_string)
@dataclass(frozen=True)
class S3StorageInfo(StorageInfo):
    """Connection details for an S3 compatible storage.

    Equality and hashing are based on the connection fields only (secure, endpoint, access key, secret key,
    region); `bucket_name` is ignored.
    """

    secure: bool
    endpoint: str
    access_key: str
    secret_key: str
    region: str

    def _connection_key(self):
        """Return the tuple of fields that identifies the connection."""
        return (self.secure, self.endpoint, self.access_key, self.secret_key, self.region)

    def __eq__(self, other):
        if isinstance(other, S3StorageInfo):
            return self._connection_key() == other._connection_key()
        return False

    def __hash__(self):
        return hash(self._connection_key())
def get_storage_from_storage_info(storage_info: StorageInfo) -> Storage:
    """Create the storage object matching the concrete type of the given storage info.

    Args:
        storage_info: Either an `AzureStorageInfo` or an `S3StorageInfo`.

    Returns:
        An `AzureStorage` wrapping a `BlobServiceClient`, or an `S3Storage` wrapping a `Minio` client.

    Raises:
        UnknownStorageBackend: If `storage_info` is neither of the two supported types.
    """
    if isinstance(storage_info, AzureStorageInfo):
        blob_service_client = BlobServiceClient.from_connection_string(conn_str=storage_info.connection_string)
        return AzureStorage(blob_service_client)

    if isinstance(storage_info, S3StorageInfo):
        minio_client = Minio(
            secure=storage_info.secure,
            endpoint=storage_info.endpoint,
            access_key=storage_info.access_key,
            secret_key=storage_info.secret_key,
            region=storage_info.region,
        )
        return S3Storage(minio_client)

    raise UnknownStorageBackend()
def get_storage_info_from_endpoint(
    public_key: str, endpoint: str, x_tenant_id: str, timeout: float = 10.0
) -> StorageInfo:
    """Fetch the storage connection information of a tenant and decrypt its secrets.

    Args:
        public_key: Key handed to `decrypt` to decrypt the encrypted connection values.
        endpoint: Base URL of the tenant endpoint; the tenant id is appended as a path segment.
        x_tenant_id: Id of the tenant whose storage connection is requested.
        timeout: Seconds to wait for the endpoint before giving up, so that an unresponsive endpoint cannot
            block the caller indefinitely.

    Returns:
        An `AzureStorageInfo` if the response contains "azureStorageConnection", an `S3StorageInfo` if it
        contains "s3StorageConnection".

    Raises:
        requests.RequestException: If the request fails, times out or returns an HTTP error status.
        ValueError: If the response contains both an Azure and an S3 storage connection.
        UnknownStorageBackend: If the response contains neither of the two storage connections.
    """
    response = requests.get(f"{endpoint}/{x_tenant_id}", timeout=timeout)
    # Fail with the HTTP error itself instead of misreporting an error body as an unknown backend.
    response.raise_for_status()
    tenant = response.json()

    maybe_azure = tenant.get("azureStorageConnection")
    maybe_s3 = tenant.get("s3StorageConnection")

    # Raised explicitly rather than asserted: assert statements are removed when Python runs with -O.
    if maybe_azure and maybe_s3:
        raise ValueError(f"Received both an Azure and an S3 storage connection for tenant '{x_tenant_id}'.")

    if maybe_azure:
        connection_string = decrypt(public_key, maybe_azure["connectionString"])
        return AzureStorageInfo(
            connection_string=connection_string,
            bucket_name=maybe_azure["containerName"],
        )

    if maybe_s3:
        # Kept under its own name so the tenant endpoint parameter is not shadowed.
        secure, s3_endpoint = validate_and_parse_s3_endpoint(maybe_s3["endpoint"])
        secret = decrypt(public_key, maybe_s3["secret"])
        return S3StorageInfo(
            secure=secure,
            endpoint=s3_endpoint,
            access_key=maybe_s3["key"],
            secret_key=secret,
            region=maybe_s3["region"],
            bucket_name=maybe_s3["bucketName"],
        )

    raise UnknownStorageBackend(f"No storage connection received for tenant '{x_tenant_id}'.")
def get_storage_info_from_config(config: Config) -> StorageInfo:
    """Build the storage info for the backend selected in the config.

    Args:
        config: Config whose `storage_backend` is either "s3" or "azure" and which carries the matching
            connection settings.

    Returns:
        An `S3StorageInfo` or `AzureStorageInfo` filled from the config values.

    Raises:
        UnknownStorageBackend: If `config.storage_backend` is neither "s3" nor "azure".
    """
    if config.storage_backend == "s3":
        return S3StorageInfo(
            secure=config.storage_secure_connection,
            endpoint=config.storage_endpoint,
            access_key=config.storage_key,
            secret_key=config.storage_secret,
            region=config.storage_region,
            bucket_name=config.storage_bucket,
        )

    if config.storage_backend == "azure":
        return AzureStorageInfo(
            connection_string=config.storage_azureconnectionstring,
            bucket_name=config.storage_bucket,
        )

    raise UnknownStorageBackend(f"Unknown storage backend '{config.storage_backend}'.")

View File

@ -77,5 +77,5 @@ class AzureStorage(Storage):
return zip(repeat(bucket_name), map(attrgetter("name"), blobs))
def get_azure_storage(config: Config):
def get_azure_storage_from_config(config: Config):
return AzureStorage(BlobServiceClient.from_connection_string(conn_str=config.storage_azureconnectionstring))

View File

@ -67,7 +67,7 @@ class S3Storage(Storage):
return zip(repeat(bucket_name), map(attrgetter("object_name"), objs))
def get_s3_storage(config: Config):
def get_s3_storage_from_config(config: Config):
return S3Storage(
Minio(
secure=config.storage_secure_connection,

49
pyinfra/utils/cipher.py Normal file
View File

@ -0,0 +1,49 @@
import base64
import os
from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives.ciphers.aead import AESGCM
from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC
def build_aes_gcm_cipher(public_key, iv=None):
    """Derive an AES key from the given key string and return an AES-GCM cipher for it.

    The 16 byte key is derived with PBKDF2-HMAC-SHA1 (65536 iterations), using `iv` as the salt.

    NOTE(review): `iv` defaults to None and is passed on unchanged as the salt; presumably the key derivation
    needs bytes here, so callers are expected to always pass an iv. Confirm.
    """
    key_material = public_key.encode("utf-8")
    key_derivation = PBKDF2HMAC(algorithm=hashes.SHA1(), length=16, salt=iv, iterations=65536)
    derived_key = key_derivation.derive(key_material)
    return AESGCM(derived_key)
def encrypt(public_key: str, plaintext: str, iv: bytes = None) -> str:
    """Encrypt a text with AES/GCM using a public key.

    The byte-converted ciphertext consists of an unsigned 32-bit integer big-endian byteorder header i.e. the first 4
    bytes, specifying the length of the following initialization vector (iv). The rest of the text contains the
    encrypted message.

    Args:
        public_key: Key string from which the AES key is derived.
        plaintext: Text to encrypt; it is UTF-8 encoded before encryption.
        iv: Initialization vector as bytes; 12 random bytes are generated if it is omitted or empty.

    Returns:
        The base64 encoded concatenation of header, iv and encrypted message, as a string.
    """
    iv = iv or os.urandom(12)
    plaintext_bytes = plaintext.encode("utf-8")
    cipher = build_aes_gcm_cipher(public_key, iv)
    # The iv length is stored up front so that decrypt can split the iv from the encrypted message again.
    header = len(iv).to_bytes(length=4, byteorder="big")
    encrypted = header + iv + cipher.encrypt(nonce=iv, data=plaintext_bytes, associated_data=None)
    return base64.b64encode(encrypted).decode("utf-8")
def decrypt(public_key: str, ciphertext: str) -> str:
    """Decrypt a base64 encoded AES/GCM ciphertext with the given key string.

    Layout of the decoded bytes: the first 4 bytes are a big-endian unsigned integer holding the length of the
    initialization vector (iv), followed by the iv itself, followed by the encrypted message.

    Returns:
        The decrypted message, decoded as UTF-8.
    """
    header_size = 4
    raw = base64.b64decode(ciphertext)

    # Locate the end of the iv from the length stored in the header.
    iv_end = header_size + int.from_bytes(raw[:header_size], "big")
    iv = raw[header_size:iv_end]
    encrypted_message = raw[iv_end:]

    cipher = build_aes_gcm_cipher(public_key, iv)
    plaintext_bytes = cipher.decrypt(nonce=iv, data=encrypted_message, associated_data=None)
    return plaintext_bytes.decode("utf-8")

5
pyinfra/utils/dict.py Normal file
View File

@ -0,0 +1,5 @@
from funcy import project
def safe_project(mapping, keys) -> dict:
    """Project `mapping` onto `keys`, returning an empty dict for a missing or empty mapping."""
    if not mapping:
        return {}
    return project(mapping, keys)

View File

@ -17,6 +17,7 @@ testcontainers = "3.4.2"
docker-compose = "1.29.2"
funcy = "1.17"
prometheus-client = "^0.16.0"
pycryptodome = "^3.17"
[tool.poetry.group.dev.dependencies]
pytest = "^7.1.3"

View File

@ -7,7 +7,7 @@ import pika
from pyinfra.config import get_config
from pyinfra.queue.development_queue_manager import DevelopmentQueueManager
from pyinfra.storage.storages.s3 import get_s3_storage
from pyinfra.storage.storages.s3 import get_s3_storage_from_config
CONFIG = get_config()
logging.basicConfig()
@ -26,7 +26,7 @@ def upload_json_and_make_message_body():
object_name = f"{dossier_id}/{file_id}.{suffix}"
data = gzip.compress(json.dumps(content).encode("utf-8"))
storage = get_s3_storage(CONFIG)
storage = get_s3_storage_from_config(CONFIG)
if not storage.has_bucket(bucket):
storage.make_bucket(bucket)
storage.put_object(bucket, object_name, data)
@ -46,10 +46,10 @@ def main():
message = upload_json_and_make_message_body()
development_queue_manager.publish_request(message, pika.BasicProperties(headers={"x-tenant-id": "redaction"}))
development_queue_manager.publish_request(message, pika.BasicProperties(headers={"X-TENANT-ID": "redaction"}))
logger.info(f"Put {message} on {CONFIG.request_queue}")
storage = get_s3_storage(CONFIG)
storage = get_s3_storage_from_config(CONFIG)
for method_frame, properties, body in development_queue_manager._channel.consume(
queue=CONFIG.response_queue, inactivity_timeout=15
):

View File

@ -1,194 +0,0 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"ename": "ModuleNotFoundError",
"evalue": "No module named 'pprint.pprint'; 'pprint' is not a package",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mModuleNotFoundError\u001b[0m Traceback (most recent call last)",
"Cell \u001b[0;32mIn [10], line 4\u001b[0m\n\u001b[1;32m 2\u001b[0m \u001b[38;5;28;01mimport\u001b[39;00m \u001b[38;5;21;01myaml\u001b[39;00m\n\u001b[1;32m 3\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01myaml\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mloader\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m FullLoader\n\u001b[0;32m----> 4\u001b[0m \u001b[38;5;28;01mimport\u001b[39;00m \u001b[38;5;21;01mpprint\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mpprint\u001b[39;00m \u001b[38;5;28;01mas\u001b[39;00m \u001b[38;5;21;01mpp\u001b[39;00m\n",
"\u001b[0;31mModuleNotFoundError\u001b[0m: No module named 'pprint.pprint'; 'pprint' is not a package"
]
}
],
"source": [
"import pyinfra\n",
"import yaml\n",
"from yaml.loader import FullLoader\n",
"import pprint"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"{'logging': 0,\n",
" 'mock_analysis_endpoint': 'http://127.0.0.1:5000',\n",
" 'service': {'operations': {'classify': {'input': {'extension': 'cls_in.gz',\n",
" 'multi': True,\n",
" 'subdir': ''},\n",
" 'output': {'extension': 'cls_out.gz',\n",
" 'subdir': ''}},\n",
" 'default': {'input': {'extension': 'IN.gz',\n",
" 'multi': False,\n",
" 'subdir': ''},\n",
" 'output': {'extension': 'OUT.gz',\n",
" 'subdir': ''}},\n",
" 'extract': {'input': {'extension': 'extr_in.gz',\n",
" 'multi': False,\n",
" 'subdir': ''},\n",
" 'output': {'extension': 'gz',\n",
" 'subdir': 'extractions'}},\n",
" 'rotate': {'input': {'extension': 'rot_in.gz',\n",
" 'multi': False,\n",
" 'subdir': ''},\n",
" 'output': {'extension': 'rot_out.gz',\n",
" 'subdir': ''}},\n",
" 'stream_pages': {'input': {'extension': 'pgs_in.gz',\n",
" 'multi': False,\n",
" 'subdir': ''},\n",
" 'output': {'extension': 'pgs_out.gz',\n",
" 'subdir': 'pages'}},\n",
" 'upper': {'input': {'extension': 'up_in.gz',\n",
" 'multi': False,\n",
" 'subdir': ''},\n",
" 'output': {'extension': 'up_out.gz',\n",
" 'subdir': ''}}},\n",
" 'response_formatter': 'identity'},\n",
" 'storage': {'aws': {'access_key': 'AKIA4QVP6D4LCDAGYGN2',\n",
" 'endpoint': 'https://s3.amazonaws.com',\n",
" 'region': '$STORAGE_REGION|\"eu-west-1\"',\n",
" 'secret_key': '8N6H1TUHTsbvW2qMAm7zZlJ63hMqjcXAsdN7TYED'},\n",
" 'azure': {'connection_string': 'DefaultEndpointsProtocol=https;AccountName=iqserdevelopment;AccountKey=4imAbV9PYXaztSOMpIyAClg88bAZCXuXMGJG0GA1eIBpdh2PlnFGoRBnKqLy2YZUSTmZ3wJfC7tzfHtuC6FEhQ==;EndpointSuffix=core.windows.net'},\n",
" 'bucket': 'pyinfra-test-bucket',\n",
" 'minio': {'access_key': 'root',\n",
" 'endpoint': 'http://127.0.0.1:9000',\n",
" 'region': None,\n",
" 'secret_key': 'password'}},\n",
" 'use_docker_fixture': 1,\n",
" 'webserver': {'host': '$SERVER_HOST|\"127.0.0.1\"',\n",
" 'mode': '$SERVER_MODE|production',\n",
" 'port': '$SERVER_PORT|5000'}}\n"
]
}
],
"source": [
"\n",
"# Open the file and load the file\n",
"with open('./tests/config.yml') as f:\n",
" data = yaml.load(f, Loader=FullLoader)\n",
" pprint.pprint(data)"
]
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"os.environ[\"STORAGE_BACKEND\"] = \"azure\"\n",
"\n",
"# always the same\n",
"os.environ[\"STORAGE_BUCKET_NAME\"] = \"pyinfra-test-bucket\"\n",
"\n",
"# s3\n",
"os.environ[\"STORAGE_ENDPOINT\"] = \"https://s3.amazonaws.com\"\n",
"os.environ[\"STORAGE_KEY\"] = \"AKIA4QVP6D4LCDAGYGN2\"\n",
"os.environ[\"STORAGE_SECRET\"] = \"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\"\n",
"os.environ[\"STORAGE_REGION\"] = \"eu-west-1\"\n",
"\n",
"# aks\n",
"os.environ[\"STORAGE_AZURECONNECTIONSTRING\"] = \"DefaultEndpointsProtocol=https;AccountName=iqserdevelopment;AccountKey=4imAbV9PYXaztSOMpIyAClg88bAZCXuXMGJG0GA1eIBpdh2PlnFGoRBnKqLy2YZUSTmZ3wJfC7tzfHtuC6FEhQ==;EndpointSuffix=core.windows.net\""
]
},
{
"cell_type": "code",
"execution_count": 23,
"metadata": {},
"outputs": [
{
"ename": "Exception",
"evalue": "Unknown storage backend 'aks'.",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mException\u001b[0m Traceback (most recent call last)",
"Cell \u001b[0;32mIn [23], line 2\u001b[0m\n\u001b[1;32m 1\u001b[0m config \u001b[38;5;241m=\u001b[39m pyinfra\u001b[38;5;241m.\u001b[39mconfig\u001b[38;5;241m.\u001b[39mget_config()\n\u001b[0;32m----> 2\u001b[0m storage \u001b[38;5;241m=\u001b[39m \u001b[43mpyinfra\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mstorage\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mget_storage\u001b[49m\u001b[43m(\u001b[49m\u001b[43mconfig\u001b[49m\u001b[43m)\u001b[49m\n",
"File \u001b[0;32m~/dev/pyinfra/pyinfra/storage/storage.py:15\u001b[0m, in \u001b[0;36mget_storage\u001b[0;34m(config)\u001b[0m\n\u001b[1;32m 13\u001b[0m storage \u001b[39m=\u001b[39m get_azure_storage(config)\n\u001b[1;32m 14\u001b[0m \u001b[39melse\u001b[39;00m:\n\u001b[0;32m---> 15\u001b[0m \u001b[39mraise\u001b[39;00m \u001b[39mException\u001b[39;00m(\u001b[39mf\u001b[39m\u001b[39m\"\u001b[39m\u001b[39mUnknown storage backend \u001b[39m\u001b[39m'\u001b[39m\u001b[39m{\u001b[39;00mconfig\u001b[39m.\u001b[39mstorage_backend\u001b[39m}\u001b[39;00m\u001b[39m'\u001b[39m\u001b[39m.\u001b[39m\u001b[39m\"\u001b[39m)\n\u001b[1;32m 17\u001b[0m \u001b[39mreturn\u001b[39;00m storage\n",
"\u001b[0;31mException\u001b[0m: Unknown storage backend 'aks'."
]
}
],
"source": [
"config = pyinfra.config.get_config()\n",
"storage = pyinfra.storage.get_storage(config)"
]
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"False"
]
},
"execution_count": 21,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"storage.has_bucket(config.storage_bucket)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3.8.13 ('pyinfra-TboPpZ8z-py3.8')",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.13"
},
"orig_nbformat": 4,
"vscode": {
"interpreter": {
"hash": "10d7419af5ea6dfec0078ebc9d6fa1a9383fe9894853f90dc7d29a81b3de2c78"
}
}
},
"nbformat": 4,
"nbformat_minor": 2
}

29
tests/cipher_test.py Normal file
View File

@ -0,0 +1,29 @@
import pytest
from pyinfra.utils.cipher import decrypt, encrypt
@pytest.fixture
def ciphertext():
    """Provide the encoded ciphertext sample that the decryption test feeds to ``decrypt``."""
    encoded_sample = "AAAADBRzag4/aAE2+rSekyI5phVZ1e0wwSaRkGQTLftPyVvq8vLYZzwxW48Wozc3/w=="
    return encoded_sample
@pytest.fixture
def plaintext():
    """Provide the plaintext the tests expect after decryption."""
    # The spelling is intentional test data: it has to match what the
    # ciphertext fixture decrypts to.
    expected_text = "connectzionString"
    return expected_text
@pytest.fixture
def public_key():
    """Provide the key string passed to ``encrypt`` and ``decrypt`` in the tests."""
    key = "redaction"
    return key
class TestDecryption:
    """Checks for the ``decrypt`` and ``encrypt`` helpers from ``pyinfra.utils.cipher``."""

    def test_decrypt_ciphertext(self, public_key, ciphertext, plaintext):
        """Decrypting the stored sample must reproduce the expected plaintext."""
        assert decrypt(public_key, ciphertext) == plaintext

    def test_encrypt_plaintext(self, public_key, plaintext):
        """Encrypting and then decrypting must give back the original plaintext."""
        encrypted = encrypt(public_key, plaintext)
        round_tripped = decrypt(public_key, encrypted)
        assert plaintext == round_tripped

View File

@ -6,8 +6,7 @@ import pytest
import testcontainers.compose
from pyinfra.config import get_config
from pyinfra.queue.queue_manager import QueueManager
from pyinfra.storage import get_storage
from pyinfra.storage import get_storage_from_config
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)
@ -30,60 +29,35 @@ def docker_compose(sleep_seconds=30):
@pytest.fixture(scope="session")
def storage_config(client_name):
def test_storage_config(storage_backend, bucket_name, monitoring_enabled):
config = get_config()
config.storage_backend = client_name
config.storage_backend = storage_backend
config.storage_bucket = bucket_name
config.storage_azureconnectionstring = "DefaultEndpointsProtocol=https;AccountName=iqserdevelopment;AccountKey=4imAbV9PYXaztSOMpIyAClg88bAZCXuXMGJG0GA1eIBpdh2PlnFGoRBnKqLy2YZUSTmZ3wJfC7tzfHtuC6FEhQ==;EndpointSuffix=core.windows.net"
config.monitoring_enabled = monitoring_enabled
config.prometheus_metric_prefix = "test"
config.prometheus_port = 8080
config.prometheus_host = "0.0.0.0"
return config
@pytest.fixture(scope="session")
def processing_config(storage_config, monitoring_enabled):
storage_config.monitoring_enabled = monitoring_enabled
return storage_config
@pytest.fixture(scope="session")
def bucket_name(storage_config):
return storage_config.storage_bucket
@pytest.fixture(scope="session")
def storage(storage_config):
logger.debug("Setup for storage")
storage = get_storage(storage_config)
storage.make_bucket(storage_config.storage_bucket)
storage.clear_bucket(storage_config.storage_bucket)
yield storage
logger.debug("Teardown for storage")
try:
storage.clear_bucket(storage_config.storage_bucket)
except:
pass
@pytest.fixture(scope="session")
def queue_config(payload_processor_type):
def test_queue_config():
config = get_config()
# FIXME: It looks like rabbitmq_heartbeat has to be greater than rabbitmq_connection_sleep. If this is expected, the
# user should not be able to insert non-working values.
config.rabbitmq_heartbeat = config.rabbitmq_connection_sleep + 1
config.rabbitmq_connection_sleep = 2
config.rabbitmq_heartbeat = 4
return config
@pytest.fixture(scope="session")
def queue_manager(queue_config):
queue_manager = QueueManager(queue_config)
return queue_manager
@pytest.fixture
def request_payload():
def payload(x_tenant_id):
x_tenant_entry = {"X-TENANT-ID": x_tenant_id} if x_tenant_id else {}
return {
"dossierId": "test",
"fileId": "test",
"targetFileExtension": "json.gz",
"responseFileExtension": "json.gz",
**x_tenant_entry,
}
@ -93,3 +67,17 @@ def response_payload():
"dossierId": "test",
"fileId": "test",
}
@pytest.fixture(scope="session")
def storage(test_storage_config):
    """Provide a storage connection with an emptied test bucket for the session.

    Setup builds the storage via ``get_storage_from_config``, then creates and
    clears the bucket named by ``test_storage_config.storage_bucket``. After the
    session the bucket is cleared once more.

    Yields:
        The storage object returned by ``get_storage_from_config``.
    """
    logger.debug("Setup for storage")
    connection = get_storage_from_config(test_storage_config)
    connection.make_bucket(test_storage_config.storage_bucket)
    connection.clear_bucket(test_storage_config.storage_bucket)

    yield connection

    logger.debug("Teardown for storage")
    try:
        connection.clear_bucket(test_storage_config.storage_bucket)
    except Exception:
        # Teardown cleanup stays best-effort, but it no longer swallows
        # KeyboardInterrupt/SystemExit and it leaves a trace when it fails.
        logger.warning("Clearing the storage bucket during teardown failed", exc_info=True)

48
tests/lru_test.py Normal file
View File

@ -0,0 +1,48 @@
from functools import lru_cache
import pytest
def func(callback):
    """Invoke ``callback`` without arguments and hand back its result."""
    outcome = callback()
    return outcome
@pytest.fixture()
def fn(maxsize):
    """Provide ``func`` wrapped in an ``lru_cache`` limited to ``maxsize`` entries."""
    decorate = lru_cache(maxsize=maxsize)
    return decorate(func)
@pytest.fixture(params=[1, 2, 5])
def maxsize(request):
    """Provide each parametrized cache size (1, 2 and 5) in turn."""
    cache_size = request.param
    return cache_size
class Callback:
    """Callable counter whose hash stays fixed while its value changes.

    Every call increments ``x`` by one and returns the new value. The hash is
    derived from the starting value only, so it does not change as ``x`` grows.
    Equality is not overridden and therefore remains identity-based.
    """

    def __init__(self, x):
        # Keep the starting value separately: it is what __hash__ uses.
        self.initial_x = self.x = x

    def __call__(self, *args, **kwargs):
        # Arguments are accepted but ignored; only the counter matters.
        self.x += 1
        current = self.x
        return current

    def __hash__(self):
        return hash(self.initial_x)
def test_adding_to_cache_within_maxsize_does_not_overwrite(fn, maxsize):
    """Repeated calls with the same key must keep returning the cached first result."""
    counter = Callback(0)
    for _ in range(maxsize):
        # The callback would return 2, 3, ... if it were invoked again, so a
        # constant 1 shows that the cached value is served.
        first_hit = fn(counter)
        assert first_hit == 1
        second_hit = fn(counter)
        assert second_hit == 1
def test_adding_to_cache_more_than_maxsize_does_overwrite(fn, maxsize):
    """Exceeding ``maxsize`` must evict the oldest entry so it is computed again."""
    callbacks = [Callback(start) for start in range(maxsize)]

    # Fill the cache: each callback starts at its index and returns index + 1.
    for index, callback in enumerate(callbacks):
        assert fn(callback) == index + 1

    # One more distinct key pushes the first entry out of the cache.
    overflow = Callback(maxsize)
    assert fn(overflow) == maxsize + 1

    # The first callback is invoked a second time, hence 2 instead of the cached 1.
    assert fn(callbacks[0]) == 2

View File

@ -4,43 +4,41 @@ import time
import pytest
import requests
from pyinfra.config import get_config
from pyinfra.payload_processing.monitor import get_monitor
from pyinfra.payload_processing.monitor import PrometheusMonitor
@pytest.fixture(scope="class")
def monitor_config():
config = get_config()
config.prometheus_metric_prefix = "monitor_test"
config.prometheus_port = 8000
return config
@pytest.fixture(scope="class")
def prometheus_monitor(monitor_config):
return get_monitor(monitor_config)
@pytest.fixture
def monitored_mock_function(prometheus_monitor):
def monitored_mock_function(metric_prefix, host, port):
def process(data=None):
time.sleep(2)
return ["result1", "result2", "result3"]
return prometheus_monitor(process)
monitor = PrometheusMonitor(metric_prefix, host, port)
return monitor(process)
@pytest.fixture
def metric_endpoint(host, port):
return f"http://{host}:{port}/prometheus"
@pytest.mark.parametrize("metric_prefix, host, port", [("test", "0.0.0.0", 8000)], scope="class")
class TestPrometheusMonitor:
def test_prometheus_endpoint_is_available(self, prometheus_monitor, monitor_config):
resp = requests.get(f"http://{monitor_config.prometheus_host}:{monitor_config.prometheus_port}/prometheus")
def test_prometheus_endpoint_is_available(self, metric_endpoint, monitored_mock_function):
resp = requests.get(metric_endpoint)
assert resp.status_code == 200
def test_processing_with_a_monitored_fn_increases_parameter_counter(self, monitored_mock_function, monitor_config):
def test_processing_with_a_monitored_fn_increases_parameter_counter(
self,
metric_endpoint,
metric_prefix,
monitored_mock_function,
):
monitored_mock_function(data=None)
resp = requests.get(f"http://{monitor_config.prometheus_host}:{monitor_config.prometheus_port}/prometheus")
pattern = re.compile(r".*monitor_test_processing_time_count (\d\.\d).*")
resp = requests.get(metric_endpoint)
pattern = re.compile(rf".*{metric_prefix}_processing_time_count (\d\.\d).*")
assert pattern.search(resp.text).group(1) == "1.0"
monitored_mock_function(data=None)
resp = requests.get(f"http://{monitor_config.prometheus_host}:{monitor_config.prometheus_port}/prometheus")
resp = requests.get(metric_endpoint)
assert pattern.search(resp.text).group(1) == "2.0"

View File

@ -0,0 +1,54 @@
import pytest
from pyinfra.payload_processing.payload import (
QueueMessagePayload,
QueueMessagePayloadParser,
)
from pyinfra.utils.file_extension_parsing import make_file_extension_parser
@pytest.fixture
def expected_parsed_payload(x_tenant_id):
    """Build the ``QueueMessagePayload`` the parser test compares against.

    All fields are fixed except ``x_tenant_id``, which is taken from the
    fixture of the same name.
    """
    fields = {
        "dossier_id": "test",
        "file_id": "test",
        "x_tenant_id": x_tenant_id,
        "target_file_extension": "json.gz",
        "response_file_extension": "json.gz",
        "target_file_type": "json",
        "target_compression_type": "gz",
        "response_file_type": "json",
        "response_compression_type": "gz",
        "target_file_name": "test/test.json.gz",
        "response_file_name": "test/test.json.gz",
        "processing_kwargs": {},
    }
    return QueueMessagePayload(**fields)
@pytest.fixture
def file_extension_parser(allowed_file_types, allowed_compression_types):
    """Provide a parser built by ``make_file_extension_parser`` from the allowed types."""
    parser = make_file_extension_parser(allowed_file_types, allowed_compression_types)
    return parser
@pytest.fixture
def payload_parser(file_extension_parser):
    """Provide a ``QueueMessagePayloadParser`` that uses the given extension parser."""
    parser = QueueMessagePayloadParser(file_extension_parser)
    return parser
@pytest.mark.parametrize("allowed_file_types,allowed_compression_types", [(["json", "pdf"], ["gz"])])
class TestPayload:
    """Tests for queue message payload parsing and file extension parsing."""

    @pytest.mark.parametrize("x_tenant_id", [None, "klaus"])
    def test_payload_is_parsed_correctly(self, payload_parser, payload, expected_parsed_payload):
        """The parsed payload must equal the expected one, with and without a tenant id."""
        parsed = payload_parser(payload)
        assert parsed == expected_parsed_payload

    @pytest.mark.parametrize(
        "extension,expected",
        [
            ("json.gz", ("json", "gz")),
            ("json", ("json", None)),
            ("prefix.json.gz", ("json", "gz")),
        ],
    )
    def test_parse_file_extension(self, file_extension_parser, extension, expected):
        """Each extension must be parsed into the expected two-element tuple."""
        actual = file_extension_parser(extension)
        assert actual == expected

View File

@ -8,14 +8,6 @@ import requests
from pyinfra.payload_processing.processor import make_payload_processor
@pytest.fixture(scope="session")
def file_processor_mock():
def inner(json_file: dict):
return [json_file]
return inner
@pytest.fixture
def target_file():
contents = {"numberOfPages": 10, "content1": "value1", "content2": "value2"}
@ -23,54 +15,60 @@ def target_file():
@pytest.fixture
def file_names(request_payload):
def file_names(payload):
dossier_id, file_id, target_suffix, response_suffix = itemgetter(
"dossierId",
"fileId",
"targetFileExtension",
"responseFileExtension",
)(request_payload)
)(payload)
return f"{dossier_id}/{file_id}.{target_suffix}", f"{dossier_id}/{file_id}.{response_suffix}"
@pytest.fixture(scope="session")
def payload_processor(file_processor_mock, processing_config):
yield make_payload_processor(file_processor_mock, processing_config)
def payload_processor(test_storage_config):
def file_processor_mock(json_file: dict):
return [json_file]
yield make_payload_processor(file_processor_mock, test_storage_config)
@pytest.mark.parametrize("client_name", ["s3"], scope="session")
@pytest.mark.parametrize("storage_backend", ["s3"], scope="session")
@pytest.mark.parametrize("bucket_name", ["testbucket"], scope="session")
@pytest.mark.parametrize("monitoring_enabled", [True, False], scope="session")
@pytest.mark.parametrize("x_tenant_id", [None])
class TestPayloadProcessor:
def test_payload_processor_yields_correct_response_and_uploads_result(
self,
payload_processor,
storage,
bucket_name,
request_payload,
payload,
response_payload,
target_file,
file_names,
):
storage.clear_bucket(bucket_name)
storage.put_object(bucket_name, file_names[0], target_file)
response = payload_processor(request_payload)
response = payload_processor(payload)
assert response == response_payload
data_received = storage.get_object(bucket_name, file_names[1])
assert json.loads((gzip.decompress(data_received)).decode("utf-8")) == {
**request_payload,
**payload,
"data": [json.loads(gzip.decompress(target_file).decode("utf-8"))],
}
def test_catching_of_processing_failure(self, payload_processor, storage, bucket_name, request_payload):
def test_catching_of_processing_failure(self, payload_processor, storage, bucket_name, payload):
storage.clear_bucket(bucket_name)
with pytest.raises(Exception):
payload_processor(request_payload)
payload_processor(payload)
def test_prometheus_endpoint_is_available(self, processing_config):
resp = requests.get(
f"http://{processing_config.prometheus_host}:{processing_config.prometheus_port}/prometheus"
)
assert resp.status_code == 200
def test_prometheus_endpoint_is_available(self, test_storage_config, monitoring_enabled, storage_backend, x_tenant_id):
if monitoring_enabled:
resp = requests.get(
f"http://{test_storage_config.prometheus_host}:{test_storage_config.prometheus_port}/prometheus"
)
assert resp.status_code == 200

View File

@ -1,44 +0,0 @@
import pytest
from pyinfra.config import get_config
from pyinfra.payload_processing.payload import (
QueueMessagePayload,
get_queue_message_payload_parser,
)
from pyinfra.utils.file_extension_parsing import make_file_extension_parser
@pytest.fixture(scope="session")
def payload_config():
    """Provide the object returned by ``get_config`` once per test session."""
    config = get_config()
    return config
class TestPayload:
    """Tests for queue message payload parsing and file extension parsing."""

    def test_payload_is_parsed_correctly(self, request_payload, payload_config):
        """Parsing the request payload must produce the expected ``QueueMessagePayload``."""
        parser = get_queue_message_payload_parser(payload_config)
        parsed = parser(request_payload)
        expected = QueueMessagePayload(
            dossier_id="test",
            file_id="test",
            target_file_extension="json.gz",
            response_file_extension="json.gz",
            target_file_type="json",
            target_compression_type="gz",
            response_file_type="json",
            response_compression_type="gz",
            target_file_name="test/test.json.gz",
            response_file_name="test/test.json.gz",
        )
        assert parsed == expected

    @pytest.mark.parametrize(
        "extension,expected",
        [
            ("json.gz", ("json", "gz")),
            ("json", ("json", None)),
            ("prefix.json.gz", ("json", "gz")),
        ],
    )
    @pytest.mark.parametrize("allowed_file_types,allowed_compression_types", [(["json", "pdf"], ["gz"])])
    def test_parse_file_extension(self, extension, expected, allowed_file_types, allowed_compression_types):
        """Each extension must be parsed into the expected two-element tuple."""
        parser = make_file_extension_parser(allowed_file_types, allowed_compression_types)
        actual = parser(extension)
        assert actual == expected

View File

@ -8,15 +8,16 @@ import pika.exceptions
import pytest
from pyinfra.queue.development_queue_manager import DevelopmentQueueManager
from pyinfra.queue.queue_manager import QueueManager
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)
@pytest.fixture(scope="session")
def development_queue_manager(queue_config):
queue_config.rabbitmq_heartbeat = 7200
development_queue_manager = DevelopmentQueueManager(queue_config)
def development_queue_manager(test_queue_config):
test_queue_config.rabbitmq_heartbeat = 7200
development_queue_manager = DevelopmentQueueManager(test_queue_config)
yield development_queue_manager
logger.info("Tearing down development queue manager...")
try:
@ -26,10 +27,10 @@ def development_queue_manager(queue_config):
@pytest.fixture(scope="session")
def payload_processing_time(queue_config, offset=5):
def payload_processing_time(test_queue_config, offset=5):
# FIXME: this implicitly tests the heartbeat when running the end-to-end test. There should be another way to test
# this explicitly.
return queue_config.rabbitmq_heartbeat + offset
return test_queue_config.rabbitmq_heartbeat + offset
@pytest.fixture(scope="session")
@ -48,10 +49,11 @@ def payload_processor(response_payload, payload_processing_time, payload_process
@pytest.fixture(scope="session", autouse=True)
def start_queue_consumer(queue_manager, payload_processor, sleep_seconds=5):
def start_queue_consumer(test_queue_config, payload_processor, sleep_seconds=5):
def consume_queue():
queue_manager.start_consuming(payload_processor)
queue_manager = QueueManager(test_queue_config)
p = Process(target=consume_queue)
p.start()
logger.info(f"Setting up consumer, waiting for {sleep_seconds}...")
@ -65,39 +67,40 @@ def start_queue_consumer(queue_manager, payload_processor, sleep_seconds=5):
def message_properties(message_headers):
if not message_headers:
return pika.BasicProperties(headers=None)
elif message_headers == "x-tenant-id":
return pika.BasicProperties(headers={"x-tenant-id": "redaction"})
elif message_headers == "X-TENANT-ID":
return pika.BasicProperties(headers={"X-TENANT-ID": "redaction"})
else:
raise Exception(f"Invalid {message_headers=}.")
@pytest.mark.parametrize("x_tenant_id", [None])
class TestQueueManager:
# FIXME: All tests here are wonky. This is due to the implementation of running the process-blocking queue_manager
# in a subprocess. It is then very hard to interact directly with the subprocess. If you have a better idea, please
# refactor; the tests here are insufficient to ensure the functionality of the queue manager!
@pytest.mark.parametrize("payload_processor_type", ["mock"], scope="session")
def test_message_processing_does_not_block_heartbeat(
self, development_queue_manager, request_payload, response_payload, payload_processing_time
self, development_queue_manager, payload, response_payload, payload_processing_time
):
development_queue_manager.clear_queues()
development_queue_manager.publish_request(request_payload)
development_queue_manager.publish_request(payload)
time.sleep(payload_processing_time + 10)
_, _, body = development_queue_manager.get_response()
result = json.loads(body)
assert result == response_payload
@pytest.mark.parametrize("message_headers", [None, "x-tenant-id"])
@pytest.mark.parametrize("message_headers", [None, "X-TENANT-ID"])
@pytest.mark.parametrize("payload_processor_type", ["mock"], scope="session")
def test_queue_manager_forwards_message_headers(
self,
development_queue_manager,
request_payload,
payload,
response_payload,
payload_processing_time,
message_properties,
):
development_queue_manager.clear_queues()
development_queue_manager.publish_request(request_payload, message_properties)
development_queue_manager.publish_request(payload, message_properties)
time.sleep(payload_processing_time + 10)
_, properties, _ = development_queue_manager.get_response()
assert properties.headers == message_properties.headers
@ -109,12 +112,12 @@ class TestQueueManager:
def test_failed_message_processing_is_handled(
self,
development_queue_manager,
request_payload,
payload,
response_payload,
payload_processing_time,
):
development_queue_manager.clear_queues()
development_queue_manager.publish_request(request_payload)
development_queue_manager.publish_request(payload)
time.sleep(payload_processing_time + 10)
_, _, body = development_queue_manager.get_response()
assert not body

View File

@ -6,7 +6,9 @@ logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)
@pytest.mark.parametrize("client_name", ["azure", "s3"], scope="session")
@pytest.mark.parametrize("storage_backend", ["azure", "s3"], scope="session")
@pytest.mark.parametrize("bucket_name", ["testbucket"], scope="session")
@pytest.mark.parametrize("monitoring_enabled", [False], scope="session")
class TestStorage:
def test_clearing_bucket_yields_empty_bucket(self, storage, bucket_name):
storage.clear_bucket(bucket_name)