diff --git a/poetry.lock b/poetry.lock index dc62908..3379718 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1360,6 +1360,49 @@ files = [ {file = "pycparser-2.21.tar.gz", hash = "sha256:e644fdec12f7872f86c58ff790da456218b10f863970249516d60a5eaca77206"}, ] +[[package]] +name = "pycryptodome" +version = "3.17" +description = "Cryptographic library for Python" +category = "main" +optional = false +python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*" +files = [ + {file = "pycryptodome-3.17-cp27-cp27m-macosx_10_9_x86_64.whl", hash = "sha256:2c5631204ebcc7ae33d11c43037b2dafe25e2ab9c1de6448eb6502ac69c19a56"}, + {file = "pycryptodome-3.17-cp27-cp27m-manylinux2010_i686.whl", hash = "sha256:04779cc588ad8f13c80a060b0b1c9d1c203d051d8a43879117fe6b8aaf1cd3fa"}, + {file = "pycryptodome-3.17-cp27-cp27m-manylinux2010_x86_64.whl", hash = "sha256:f812d58c5af06d939b2baccdda614a3ffd80531a26e5faca2c9f8b1770b2b7af"}, + {file = "pycryptodome-3.17-cp27-cp27m-manylinux2014_aarch64.whl", hash = "sha256:9453b4e21e752df8737fdffac619e93c9f0ec55ead9a45df782055eb95ef37d9"}, + {file = "pycryptodome-3.17-cp27-cp27m-musllinux_1_1_aarch64.whl", hash = "sha256:121d61663267f73692e8bde5ec0d23c9146465a0d75cad75c34f75c752527b01"}, + {file = "pycryptodome-3.17-cp27-cp27m-win32.whl", hash = "sha256:ba2d4fcb844c6ba5df4bbfee9352ad5352c5ae939ac450e06cdceff653280450"}, + {file = "pycryptodome-3.17-cp27-cp27m-win_amd64.whl", hash = "sha256:87e2ca3aa557781447428c4b6c8c937f10ff215202ab40ece5c13a82555c10d6"}, + {file = "pycryptodome-3.17-cp27-cp27mu-manylinux2010_i686.whl", hash = "sha256:f44c0d28716d950135ff21505f2c764498eda9d8806b7c78764165848aa419bc"}, + {file = "pycryptodome-3.17-cp27-cp27mu-manylinux2010_x86_64.whl", hash = "sha256:5a790bc045003d89d42e3b9cb3cc938c8561a57a88aaa5691512e8540d1ae79c"}, + {file = "pycryptodome-3.17-cp27-cp27mu-manylinux2014_aarch64.whl", hash = "sha256:d086d46774e27b280e4cece8ab3d87299cf0d39063f00f1e9290d096adc5662a"}, + {file = "pycryptodome-3.17-cp27-cp27mu-musllinux_1_1_aarch64.whl", hash = "sha256:5587803d5b66dfd99e7caa31ed91fba0fdee3661c5d93684028ad6653fce725f"}, + {file = "pycryptodome-3.17-cp35-abi3-macosx_10_9_universal2.whl", hash = "sha256:e7debd9c439e7b84f53be3cf4ba8b75b3d0b6e6015212355d6daf44ac672e210"}, + {file = "pycryptodome-3.17-cp35-abi3-macosx_10_9_x86_64.whl", hash = "sha256:ca1ceb6303be1282148f04ac21cebeebdb4152590842159877778f9cf1634f09"}, + {file = "pycryptodome-3.17-cp35-abi3-manylinux2014_aarch64.whl", hash = "sha256:dc22cc00f804485a3c2a7e2010d9f14a705555f67020eb083e833cabd5bd82e4"}, + {file = "pycryptodome-3.17-cp35-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:80ea8333b6a5f2d9e856ff2293dba2e3e661197f90bf0f4d5a82a0a6bc83a626"}, + {file = "pycryptodome-3.17-cp35-abi3-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:c133f6721fba313722a018392a91e3c69d3706ae723484841752559e71d69dc6"}, + {file = "pycryptodome-3.17-cp35-abi3-musllinux_1_1_aarch64.whl", hash = "sha256:333306eaea01fde50a73c4619e25631e56c4c61bd0fb0a2346479e67e3d3a820"}, + {file = "pycryptodome-3.17-cp35-abi3-musllinux_1_1_i686.whl", hash = "sha256:1a30f51b990994491cec2d7d237924e5b6bd0d445da9337d77de384ad7f254f9"}, + {file = "pycryptodome-3.17-cp35-abi3-musllinux_1_1_x86_64.whl", hash = "sha256:909e36a43fe4a8a3163e9c7fc103867825d14a2ecb852a63d3905250b308a4e5"}, + {file = "pycryptodome-3.17-cp35-abi3-win32.whl", hash = "sha256:a3228728a3808bc9f18c1797ec1179a0efb5068c817b2ffcf6bcd012494dffb2"}, + {file = "pycryptodome-3.17-cp35-abi3-win_amd64.whl", hash = "sha256:9ec565e89a6b400eca814f28d78a9ef3f15aea1df74d95b28b7720739b28f37f"}, + {file = "pycryptodome-3.17-pp27-pypy_73-macosx_10_9_x86_64.whl", hash = "sha256:e1819b67bcf6ca48341e9b03c2e45b1c891fa8eb1a8458482d14c2805c9616f2"}, + {file = "pycryptodome-3.17-pp27-pypy_73-manylinux2010_x86_64.whl", hash = "sha256:f8e550caf52472ae9126953415e4fc554ab53049a5691c45b8816895c632e4d7"}, + {file = "pycryptodome-3.17-pp27-pypy_73-win32.whl", hash = "sha256:afbcdb0eda20a0e1d44e3a1ad6d4ec3c959210f4b48cabc0e387a282f4c7deb8"}, + {file = "pycryptodome-3.17-pp38-pypy38_pp73-macosx_10_9_x86_64.whl", hash = "sha256:a74f45aee8c5cc4d533e585e0e596e9f78521e1543a302870a27b0ae2106381e"}, + {file = "pycryptodome-3.17-pp38-pypy38_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:38bbd6717eac084408b4094174c0805bdbaba1f57fc250fd0309ae5ec9ed7e09"}, + {file = "pycryptodome-3.17-pp38-pypy38_pp73-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:f68d6c8ea2974a571cacb7014dbaada21063a0375318d88ac1f9300bc81e93c3"}, + {file = "pycryptodome-3.17-pp38-pypy38_pp73-win_amd64.whl", hash = "sha256:8198f2b04c39d817b206ebe0db25a6653bb5f463c2319d6f6d9a80d012ac1e37"}, + {file = "pycryptodome-3.17-pp39-pypy39_pp73-macosx_10_9_x86_64.whl", hash = "sha256:3a232474cd89d3f51e4295abe248a8b95d0332d153bf46444e415409070aae1e"}, + {file = "pycryptodome-3.17-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4992ec965606054e8326e83db1c8654f0549cdb26fce1898dc1a20bc7684ec1c"}, + {file = "pycryptodome-3.17-pp39-pypy39_pp73-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:53068e33c74f3b93a8158dacaa5d0f82d254a81b1002e0cd342be89fcb3433eb"}, + {file = "pycryptodome-3.17-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:74794a2e2896cd0cf56fdc9db61ef755fa812b4a4900fa46c49045663a92b8d0"}, + {file = "pycryptodome-3.17.tar.gz", hash = "sha256:bce2e2d8e82fcf972005652371a3e8731956a0c1fbb719cc897943b3695ad91b"}, +] + [[package]] name = "pygments" version = "2.14.0" @@ -2041,4 +2084,4 @@ testing = ["big-O", "flake8 (<5)", "jaraco.functools", "jaraco.itertools", "more [metadata] lock-version = "2.0" python-versions = "~3.8" -content-hash = "d5017298c5fddc7d919a319fec7e050e8f7603c0c11c08d5c6c3115e751083e2" +content-hash = "43b21cbd83ef95d9d28a72856f9b2d78ad4f0c8c740af58bf8b40d1c61dba50d" diff --git a/pyinfra/config.py b/pyinfra/config.py index 2039c93..0afb57c 100644 --- a/pyinfra/config.py +++ b/pyinfra/config.py @@ -92,6 +92,10 @@ class Config: self.allowed_file_types = ["json", "pdf"] self.allowed_compression_types = ["gz"] + # config for x-tenant-endpoint to receive storage connection information per tenant + self.persistence_service_public_key = "redaction" + self.persistence_service_tenant_endpoint = "http://persistence-service-v1:8080/internal-api/tenants" + # Value to see if we should write a consumer token to a file self.write_consumer_token = read_from_environment("WRITE_CONSUMER_TOKEN", "False") diff --git a/pyinfra/exception.py b/pyinfra/exception.py index 74a9dcb..b8d35de 100644 --- a/pyinfra/exception.py +++ b/pyinfra/exception.py @@ -1,2 +1,5 @@ class ProcessingFailure(RuntimeError): pass + +class UnknownStorageBackend(Exception): + pass \ No newline at end of file diff --git a/pyinfra/payload_processing/monitor.py b/pyinfra/payload_processing/monitor.py index b7f2530..a507bd9 100644 --- a/pyinfra/payload_processing/monitor.py +++ b/pyinfra/payload_processing/monitor.py @@ -12,7 +12,7 @@ logger = logging.getLogger() class PrometheusMonitor: - def __init__(self, prefix: str, port=8080, host="127.0.0.1"): + def __init__(self, prefix: str, host: str, port: int): """Register the monitoring metrics and start a webserver where they can be scraped at the endpoint http://{host}:{port}/prometheus @@ -23,12 +23,9 @@ class PrometheusMonitor: self.registry = CollectorRegistry() self.entity_processing_time_sum = Summary( - f"{prefix}_processing_time", - "Summed up average processing time per entity observed", + f"{prefix}_processing_time", "Summed up average processing time per entity observed", registry=self.registry ) - self.registry.register(self.entity_processing_time_sum) - start_http_server(port, host, self.registry) def __call__(self, process_fn: Callable) -> Callable: @@ -58,8 +55,8 @@ class PrometheusMonitor: return inner -def get_monitor(config: Config) -> Callable: +def get_monitor_from_config(config: Config) -> Callable: if config.monitoring_enabled: - return PrometheusMonitor(*attrgetter("prometheus_metric_prefix", "prometheus_port", "prometheus_host")(config)) + return PrometheusMonitor(*attrgetter("prometheus_metric_prefix", "prometheus_host", "prometheus_port")(config)) else: return identity diff --git a/pyinfra/payload_processing/payload.py b/pyinfra/payload_processing/payload.py index 48d4bd2..ef5f394 100644 --- a/pyinfra/payload_processing/payload.py +++ b/pyinfra/payload_processing/payload.py @@ -3,6 +3,8 @@ from itertools import chain from operator import itemgetter from typing import Union, Sized +from funcy import project + from pyinfra.config import Config from pyinfra.utils.file_extension_parsing import make_file_extension_parser @@ -11,6 +13,8 @@ from pyinfra.utils.file_extension_parsing import make_file_extension_parser class QueueMessagePayload: dossier_id: str file_id: str + x_tenant_id: Union[str, None] + target_file_extension: str response_file_extension: str @@ -22,10 +26,13 @@ class QueueMessagePayload: target_file_name: str response_file_name: str + processing_kwargs: dict + class QueueMessagePayloadParser: - def __init__(self, file_extension_parser): + def __init__(self, file_extension_parser, allowed_processing_args=("operation",)): self.parse_file_extensions = file_extension_parser + self.allowed_args = allowed_processing_args def __call__(self, payload: dict) -> QueueMessagePayload: """Translate the queue message payload to the internal QueueMessagePayload object.""" @@ -35,6 +42,7 @@ class QueueMessagePayloadParser: dossier_id, file_id, target_file_extension, response_file_extension = itemgetter( "dossierId", "fileId", "targetFileExtension", "responseFileExtension" )(payload) + x_tenant_id = payload.get("X-TENANT-ID") target_file_type, target_compression_type, response_file_type, response_compression_type = chain.from_iterable( map(self.parse_file_extensions, [target_file_extension, response_file_extension]) @@ -43,9 +51,12 @@ class QueueMessagePayloadParser: target_file_name = f"{dossier_id}/{file_id}.{target_file_extension}" response_file_name = f"{dossier_id}/{file_id}.{response_file_extension}" + processing_kwargs = project(payload, self.allowed_args) + return QueueMessagePayload( dossier_id=dossier_id, file_id=file_id, + x_tenant_id=x_tenant_id, target_file_extension=target_file_extension, response_file_extension=response_file_extension, target_file_type=target_file_type, @@ -54,6 +65,7 @@ class QueueMessagePayloadParser: response_compression_type=response_compression_type, target_file_name=target_file_name, response_file_name=response_file_name, + processing_kwargs=processing_kwargs, ) diff --git a/pyinfra/payload_processing/processor.py b/pyinfra/payload_processing/processor.py index ebfcf81..37f31a1 100644 --- a/pyinfra/payload_processing/processor.py +++ b/pyinfra/payload_processing/processor.py @@ -1,21 +1,23 @@ import logging from dataclasses import asdict from functools import partial -from typing import Callable, Union, List - -from funcy import compose +from typing import Callable, List from pyinfra.config import get_config, Config -from pyinfra.payload_processing.monitor import get_monitor +from pyinfra.payload_processing.monitor import get_monitor_from_config from pyinfra.payload_processing.payload import ( QueueMessagePayloadParser, get_queue_message_payload_parser, QueueMessagePayloadFormatter, get_queue_message_payload_formatter, ) -from pyinfra.storage import get_storage from pyinfra.storage.storage import make_downloader, make_uploader -from pyinfra.storage.storages.interface import Storage +from pyinfra.storage.storage_info import ( + get_storage_info_from_config, + get_storage_info_from_endpoint, + StorageInfo, + get_storage_from_storage_info, +) logger = logging.getLogger() logger.setLevel(get_config().logging_level_root) @@ -24,8 +26,8 @@ logger.setLevel(get_config().logging_level_root) class PayloadProcessor: def __init__( self, - storage: Storage, - bucket: str, + default_storage_info: StorageInfo, + get_storage_info_from_tenant_id, payload_parser: QueueMessagePayloadParser, payload_formatter: QueueMessagePayloadFormatter, data_processor: Callable, @@ -33,21 +35,22 @@ class PayloadProcessor: """Wraps an analysis function specified by a service (e.g. NER service) in pre- and post-processing steps. Args: - storage: The storage to use for downloading and uploading files - bucket: The bucket to use for downloading and uploading files + default_storage_info: The default storage info used to create the storage connection. This is only used if + x_tenant_id is not provided in the queue payload. + get_storage_info_from_tenant_id: Callable to acquire storage info from a given tenant id. payload_parser: Parser that translates the queue message payload to the required QueueMessagePayload object payload_formatter: Formatter for the storage upload result and the queue message response body data_processor: The analysis function to be called with the downloaded file - NOTE: the result of the analysis function has to be Sized, e.g. a dict or a list to be able to upload it - and to be able to monitor the processing time. + NOTE: The result of the analysis function has to be an instance of `Sized`, e.g. a dict or a list to be + able to upload it and to be able to monitor the processing time. """ self.parse_payload = payload_parser self.format_result_for_storage = payload_formatter.format_service_processing_result_for_storage self.format_to_queue_message_response_body = payload_formatter.format_to_queue_message_response_body self.process_data = data_processor - self.partial_download_fn = partial(make_downloader, storage, bucket) - self.partial_upload_fn = partial(make_uploader, storage, bucket) + self.get_storage_info_from_tenant_id = get_storage_info_from_tenant_id + self.default_storage_info = default_storage_info def __call__(self, queue_message_payload: dict) -> dict: """Processes a queue message payload. @@ -71,29 +74,58 @@ class PayloadProcessor: payload = self.parse_payload(queue_message_payload) logger.info(f"Processing {asdict(payload)} ...") - download_file_to_process = self.partial_download_fn(payload.target_file_type, payload.target_compression_type) - upload_processing_result = self.partial_upload_fn(payload.response_file_type, payload.response_compression_type) + storage_info = self._get_storage_info(payload.x_tenant_id) + storage = get_storage_from_storage_info(storage_info) + bucket = storage_info.bucket_name + + download_file_to_process = make_downloader( + storage, bucket, payload.target_file_type, payload.target_compression_type + ) + upload_processing_result = make_uploader( + storage, bucket, payload.response_file_type, payload.response_compression_type + ) format_result_for_storage = partial(self.format_result_for_storage, payload) - processing_pipeline = compose(format_result_for_storage, self.process_data, download_file_to_process) + data = download_file_to_process(payload.target_file_name) + result: List[dict] = self.process_data(data, **payload.processing_kwargs) + formatted_result = format_result_for_storage(result) - result: List[dict] = processing_pipeline(payload.target_file_name) - - upload_processing_result(payload.response_file_name, result) + upload_processing_result(payload.response_file_name, formatted_result) return self.format_to_queue_message_response_body(payload) + def _get_storage_info(self, x_tenant_id=None): + if x_tenant_id: + storage_info = self.get_storage_info_from_tenant_id(x_tenant_id) + logger.info(f"Received {storage_info.__class__.__name__} for {x_tenant_id} from endpoint.") + logger.debug(f"{asdict(storage_info)}") + else: + storage_info = self.default_storage_info + logger.info(f"Using local default {storage_info.__class__.__name__} for {x_tenant_id}.") + logger.debug(f"{asdict(storage_info)}") + return storage_info -def make_payload_processor(data_processor: Callable, config: Union[None, Config] = None) -> PayloadProcessor: + +def make_payload_processor(data_processor: Callable, config: Config = None) -> PayloadProcessor: """Produces payload processor for queue manager.""" config = config or get_config() - bucket: str = config.storage_bucket - storage: Storage = get_storage(config) - monitor = get_monitor(config) + default_storage_info: StorageInfo = get_storage_info_from_config(config) + get_storage_info_from_tenant_id = partial( + get_storage_info_from_endpoint, + config.persistence_service_public_key, + config.persistence_service_tenant_endpoint, + ) + monitor = get_monitor_from_config(config) payload_parser: QueueMessagePayloadParser = get_queue_message_payload_parser(config) payload_formatter: QueueMessagePayloadFormatter = get_queue_message_payload_formatter() data_processor = monitor(data_processor) - return PayloadProcessor(storage, bucket, payload_parser, payload_formatter, data_processor) + return PayloadProcessor( + default_storage_info, + get_storage_info_from_tenant_id, + payload_parser, + payload_formatter, + data_processor, + ) diff --git a/pyinfra/queue/queue_manager.py b/pyinfra/queue/queue_manager.py index ec241a0..f5ffb98 100644 --- a/pyinfra/queue/queue_manager.py +++ b/pyinfra/queue/queue_manager.py @@ -12,6 +12,7 @@ from pika.adapters.blocking_connection import BlockingChannel from pyinfra.config import Config from pyinfra.exception import ProcessingFailure from pyinfra.payload_processing.processor import PayloadProcessor +from pyinfra.utils.dict import safe_project CONFIG = Config() @@ -164,8 +165,8 @@ class QueueManager: except Exception as err: raise ProcessingFailure("QueueMessagePayload processing failed") from err - def acknowledge_message_and_publish_response(frame, properties, response_body): - response_properties = pika.BasicProperties(headers=properties.headers) if properties.headers else None + def acknowledge_message_and_publish_response(frame, headers, response_body): + response_properties = pika.BasicProperties(headers=headers) if headers else None self._channel.basic_publish("", self._output_queue, json.dumps(response_body).encode(), response_properties) self.logger.info( "Result published, acknowledging incoming message with delivery_tag %s", @@ -190,12 +191,15 @@ class QueueManager: try: self.logger.debug("Processing (%s, %s, %s)", frame, properties, body) - processing_result = process_message_body_and_await_result(json.loads(body)) + filtered_message_headers = safe_project(properties.headers, ["X-TENANT-ID"]) # TODO: parametrize key? + message_body = {**json.loads(body), **filtered_message_headers} + + processing_result = process_message_body_and_await_result(message_body) self.logger.info( "Processed message with delivery_tag %s, publishing result to result-queue", frame.delivery_tag, ) - acknowledge_message_and_publish_response(frame, properties, processing_result) + acknowledge_message_and_publish_response(frame, filtered_message_headers, processing_result) except ProcessingFailure: self.logger.info( diff --git a/pyinfra/storage/__init__.py b/pyinfra/storage/__init__.py index f5d004f..dccdcda 100644 --- a/pyinfra/storage/__init__.py +++ b/pyinfra/storage/__init__.py @@ -1,3 +1,3 @@ -from pyinfra.storage.storage import get_storage +from pyinfra.storage.storage import get_storage_from_config -__all__ = ["get_storage"] +__all__ = ["get_storage_from_config"] diff --git a/pyinfra/storage/storage.py b/pyinfra/storage/storage.py index 40128b8..bd849d8 100644 --- a/pyinfra/storage/storage.py +++ b/pyinfra/storage/storage.py @@ -4,21 +4,16 @@ from typing import Callable from funcy import compose from pyinfra.config import Config -from pyinfra.storage.storages.azure import get_azure_storage -from pyinfra.storage.storages.s3 import get_s3_storage +from pyinfra.storage.storage_info import get_storage_info_from_config, get_storage_from_storage_info from pyinfra.storage.storages.interface import Storage from pyinfra.utils.compressing import get_decompressor, get_compressor from pyinfra.utils.encoding import get_decoder, get_encoder -def get_storage(config: Config) -> Storage: +def get_storage_from_config(config: Config) -> Storage: - if config.storage_backend == "s3": - storage = get_s3_storage(config) - elif config.storage_backend == "azure": - storage = get_azure_storage(config) - else: - raise Exception(f"Unknown storage backend '{config.storage_backend}'.") + storage_info = get_storage_info_from_config(config) + storage = get_storage_from_storage_info(storage_info) return storage diff --git a/pyinfra/storage/storage_info.py b/pyinfra/storage/storage_info.py new file mode 100644 index 0000000..aefa15b --- /dev/null +++ b/pyinfra/storage/storage_info.py @@ -0,0 +1,125 @@ +from dataclasses import dataclass + +import requests +from azure.storage.blob import BlobServiceClient +from minio import Minio + +from pyinfra.config import Config +from pyinfra.exception import UnknownStorageBackend +from pyinfra.storage.storages.azure import AzureStorage +from pyinfra.storage.storages.interface import Storage +from pyinfra.storage.storages.s3 import S3Storage +from pyinfra.utils.cipher import decrypt +from pyinfra.utils.url_parsing import validate_and_parse_s3_endpoint + + +@dataclass(frozen=True) +class StorageInfo: + bucket_name: str + + +@dataclass(frozen=True) +class AzureStorageInfo(StorageInfo): + connection_string: str + + def __hash__(self): + return hash(self.connection_string) + + def __eq__(self, other): + if not isinstance(other, AzureStorageInfo): + return False + return self.connection_string == other.connection_string + + +@dataclass(frozen=True) +class S3StorageInfo(StorageInfo): + secure: bool + endpoint: str + access_key: str + secret_key: str + region: str + + def __hash__(self): + return hash((self.secure, self.endpoint, self.access_key, self.secret_key, self.region)) + + def __eq__(self, other): + if not isinstance(other, S3StorageInfo): + return False + return ( + self.secure == other.secure + and self.endpoint == other.endpoint + and self.access_key == other.access_key + and self.secret_key == other.secret_key + and self.region == other.region + ) + + +def get_storage_from_storage_info(storage_info: StorageInfo) -> Storage: + if isinstance(storage_info, AzureStorageInfo): + return AzureStorage(BlobServiceClient.from_connection_string(conn_str=storage_info.connection_string)) + elif isinstance(storage_info, S3StorageInfo): + return S3Storage( + Minio( + secure=storage_info.secure, + endpoint=storage_info.endpoint, + access_key=storage_info.access_key, + secret_key=storage_info.secret_key, + region=storage_info.region, + ) + ) + else: + raise UnknownStorageBackend() + + +def get_storage_info_from_endpoint(public_key: str, endpoint: str, x_tenant_id: str) -> StorageInfo: + resp = requests.get(f"{endpoint}/{x_tenant_id}").json() + + maybe_azure = resp.get("azureStorageConnection") + maybe_s3 = resp.get("s3StorageConnection") + assert not (maybe_azure and maybe_s3) + + if maybe_azure: + connection_string = decrypt(public_key, maybe_azure["connectionString"]) + storage_info = AzureStorageInfo( + connection_string=connection_string, + bucket_name=maybe_azure["containerName"], + ) + elif maybe_s3: + secure, endpoint = validate_and_parse_s3_endpoint(maybe_s3["endpoint"]) + secret = decrypt(public_key, maybe_s3["secret"]) + + storage_info = S3StorageInfo( + secure=secure, + endpoint=endpoint, + access_key=maybe_s3["key"], + secret_key=secret, + region=maybe_s3["region"], + bucket_name=maybe_s3["bucketName"], + ) + else: + raise UnknownStorageBackend() + + return storage_info + + +def get_storage_info_from_config(config: Config) -> StorageInfo: + if config.storage_backend == "s3": + storage_info = S3StorageInfo( + secure=config.storage_secure_connection, + endpoint=config.storage_endpoint, + access_key=config.storage_key, + secret_key=config.storage_secret, + region=config.storage_region, + bucket_name=config.storage_bucket, + ) + + elif config.storage_backend == "azure": + storage_info = AzureStorageInfo( + connection_string=config.storage_azureconnectionstring, + bucket_name=config.storage_bucket, + ) + + else: + raise UnknownStorageBackend(f"Unknown storage backend '{config.storage_backend}'.") + + return storage_info diff --git a/pyinfra/storage/storages/azure.py b/pyinfra/storage/storages/azure.py index aaa9ba2..f6091a4 100644 --- a/pyinfra/storage/storages/azure.py +++ b/pyinfra/storage/storages/azure.py @@ -77,5 +77,5 @@ class AzureStorage(Storage): return zip(repeat(bucket_name), map(attrgetter("name"), blobs)) -def get_azure_storage(config: Config): +def get_azure_storage_from_config(config: Config): return AzureStorage(BlobServiceClient.from_connection_string(conn_str=config.storage_azureconnectionstring)) diff --git a/pyinfra/storage/storages/s3.py b/pyinfra/storage/storages/s3.py index 5763353..3eafeff 100644 --- a/pyinfra/storage/storages/s3.py +++ b/pyinfra/storage/storages/s3.py @@ -67,7 +67,7 @@ class S3Storage(Storage): return zip(repeat(bucket_name), map(attrgetter("object_name"), objs)) -def get_s3_storage(config: Config): +def get_s3_storage_from_config(config: Config): return S3Storage( Minio( secure=config.storage_secure_connection, diff --git a/pyinfra/utils/cipher.py b/pyinfra/utils/cipher.py new file mode 100644 index 0000000..6eee2ad --- /dev/null +++ b/pyinfra/utils/cipher.py @@ -0,0 +1,49 @@ +import base64 +import os + +from cryptography.hazmat.primitives import hashes +from cryptography.hazmat.primitives.ciphers.aead import AESGCM +from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC + + +def build_aes_gcm_cipher(public_key, iv=None): + encoded_key = public_key.encode("utf-8") + kdf = PBKDF2HMAC( + algorithm=hashes.SHA1(), + length=16, + salt=iv, + iterations=65536, + ) + private_key = kdf.derive(encoded_key) + return AESGCM(private_key) + + +def encrypt(public_key: str, plaintext: str, iv: int = None) -> str: + """Encrypt a text with AES/GCS using a public key. + + The byte-converted ciphertext consists of an unsigned 32-bit integer big-endian byteorder header i.e. the first 4 + bytes, specifying the length of the following initialization vector (iv). The rest of the text contains the + encrypted message. + """ + iv = iv or os.urandom(12) + plaintext_bytes = plaintext.encode("utf-8") + cipher = build_aes_gcm_cipher(public_key, iv) + header = len(iv).to_bytes(length=4, byteorder="big") + encrypted = header + iv + cipher.encrypt(nonce=iv, data=plaintext_bytes, associated_data=None) + return base64.b64encode(encrypted).decode("utf-8") + + +def decrypt(public_key: str, ciphertext: str) -> str: + """Decrypt an AES/GCS encrypted text with a public key. + + The byte-converted ciphertext consists of an unsigned 32-bit integer big-endian byteorder header i.e. the first 4 + bytes, specifying the length of the following initialization vector (iv). The rest of the text contains the + encrypted message. + """ + ciphertext_bytes = base64.b64decode(ciphertext) + header, rest = ciphertext_bytes[:4], ciphertext_bytes[4:] + iv_length = int.from_bytes(header, "big") + iv, ciphertext_bytes = rest[:iv_length], rest[iv_length:] + cipher = build_aes_gcm_cipher(public_key, iv) + decrypted_text = cipher.decrypt(nonce=iv, data=ciphertext_bytes, associated_data=None) + return decrypted_text.decode("utf-8") diff --git a/pyinfra/utils/dict.py b/pyinfra/utils/dict.py new file mode 100644 index 0000000..a732a6d --- /dev/null +++ b/pyinfra/utils/dict.py @@ -0,0 +1,5 @@ +from funcy import project + + +def safe_project(mapping, keys) -> dict: + return project(mapping, keys) if mapping else {} diff --git a/pyproject.toml b/pyproject.toml index 0eff653..7f50f86 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -17,6 +17,7 @@ testcontainers = "3.4.2" docker-compose = "1.29.2" funcy = "1.17" prometheus-client = "^0.16.0" +pycryptodome = "^3.17" [tool.poetry.group.dev.dependencies] pytest = "^7.1.3" diff --git a/scripts/send_request.py b/scripts/send_request.py index 2b9c4b7..a30f725 100644 --- a/scripts/send_request.py +++ b/scripts/send_request.py @@ -7,7 +7,7 @@ import pika from pyinfra.config import get_config from pyinfra.queue.development_queue_manager import DevelopmentQueueManager -from pyinfra.storage.storages.s3 import get_s3_storage +from pyinfra.storage.storages.s3 import get_s3_storage_from_config CONFIG = get_config() logging.basicConfig() @@ -26,7 +26,7 @@ def upload_json_and_make_message_body(): object_name = f"{dossier_id}/{file_id}.{suffix}" data = gzip.compress(json.dumps(content).encode("utf-8")) - storage = get_s3_storage(CONFIG) + storage = get_s3_storage_from_config(CONFIG) if not storage.has_bucket(bucket): storage.make_bucket(bucket) storage.put_object(bucket, object_name, data) @@ -46,10 +46,10 @@ def main(): message = upload_json_and_make_message_body() - development_queue_manager.publish_request(message, pika.BasicProperties(headers={"x-tenant-id": "redaction"})) + development_queue_manager.publish_request(message, pika.BasicProperties(headers={"X-TENANT-ID": "redaction"})) logger.info(f"Put {message} on {CONFIG.request_queue}") - storage = get_s3_storage(CONFIG) + storage = get_s3_storage_from_config(CONFIG) for method_frame, properties, body in development_queue_manager._channel.consume( queue=CONFIG.response_queue, inactivity_timeout=15 ): diff --git a/test.ipynb b/test.ipynb deleted file mode 100644 index 9f76ee4..0000000 --- a/test.ipynb +++ /dev/null @@ -1,194 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": 10, - "metadata": {}, - "outputs": [ - { - "ename": "ModuleNotFoundError", - "evalue": "No module named 'pprint.pprint'; 'pprint' is not a package", - "output_type": "error", - "traceback": [ - "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mModuleNotFoundError\u001b[0m Traceback (most recent call last)", - "Cell \u001b[0;32mIn [10], line 4\u001b[0m\n\u001b[1;32m 2\u001b[0m \u001b[38;5;28;01mimport\u001b[39;00m \u001b[38;5;21;01myaml\u001b[39;00m\n\u001b[1;32m 3\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01myaml\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mloader\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m FullLoader\n\u001b[0;32m----> 4\u001b[0m \u001b[38;5;28;01mimport\u001b[39;00m \u001b[38;5;21;01mpprint\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mpprint\u001b[39;00m \u001b[38;5;28;01mas\u001b[39;00m \u001b[38;5;21;01mpp\u001b[39;00m\n", - "\u001b[0;31mModuleNotFoundError\u001b[0m: No module named 'pprint.pprint'; 'pprint' is not a package" - ] - } - ], - "source": [ - "import pyinfra\n", - "import yaml\n", - "from yaml.loader import FullLoader\n", - "import pprint" - ] - }, - { - "cell_type": "code", - "execution_count": 12, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "{'logging': 0,\n", - " 'mock_analysis_endpoint': 'http://127.0.0.1:5000',\n", - " 'service': {'operations': {'classify': {'input': {'extension': 'cls_in.gz',\n", - " 'multi': True,\n", - " 'subdir': ''},\n", - " 'output': {'extension': 'cls_out.gz',\n", - " 'subdir': ''}},\n", - " 'default': {'input': {'extension': 'IN.gz',\n", - " 'multi': False,\n", - " 'subdir': ''},\n", - " 'output': {'extension': 'OUT.gz',\n", - " 'subdir': ''}},\n", - " 'extract': {'input': {'extension': 'extr_in.gz',\n", - " 'multi': False,\n", - " 'subdir': ''},\n", - " 'output': {'extension': 'gz',\n", - " 'subdir': 'extractions'}},\n", - " 'rotate': {'input': {'extension': 'rot_in.gz',\n", - " 'multi': False,\n", - " 'subdir': ''},\n", - " 'output': {'extension': 'rot_out.gz',\n", - " 'subdir': ''}},\n", - " 'stream_pages': {'input': {'extension': 'pgs_in.gz',\n", - " 'multi': False,\n", - " 'subdir': ''},\n", - " 'output': {'extension': 'pgs_out.gz',\n", - " 'subdir': 'pages'}},\n", - " 'upper': {'input': {'extension': 'up_in.gz',\n", - " 'multi': False,\n", - " 'subdir': ''},\n", - " 'output': {'extension': 'up_out.gz',\n", - " 'subdir': ''}}},\n", - " 'response_formatter': 'identity'},\n", - " 'storage': {'aws': {'access_key': 'AKIA4QVP6D4LCDAGYGN2',\n", - " 'endpoint': 'https://s3.amazonaws.com',\n", - " 'region': '$STORAGE_REGION|\"eu-west-1\"',\n", - " 'secret_key': '8N6H1TUHTsbvW2qMAm7zZlJ63hMqjcXAsdN7TYED'},\n", - " 'azure': {'connection_string': 'DefaultEndpointsProtocol=https;AccountName=iqserdevelopment;AccountKey=4imAbV9PYXaztSOMpIyAClg88bAZCXuXMGJG0GA1eIBpdh2PlnFGoRBnKqLy2YZUSTmZ3wJfC7tzfHtuC6FEhQ==;EndpointSuffix=core.windows.net'},\n", - " 'bucket': 'pyinfra-test-bucket',\n", - " 'minio': {'access_key': 'root',\n", - " 'endpoint': 'http://127.0.0.1:9000',\n", - " 'region': None,\n", - " 'secret_key': 'password'}},\n", - " 'use_docker_fixture': 1,\n", - " 'webserver': {'host': '$SERVER_HOST|\"127.0.0.1\"',\n", - " 'mode': '$SERVER_MODE|production',\n", - " 'port': '$SERVER_PORT|5000'}}\n" - ] - } - ], - "source": [ - "\n", - "# Open the file and load the file\n", - "with open('./tests/config.yml') as f:\n", - " data = yaml.load(f, Loader=FullLoader)\n", - " pprint.pprint(data)" - ] - }, - { - "cell_type": "code", - "execution_count": 22, - "metadata": {}, - "outputs": [], - "source": [ - "import os\n", - "os.environ[\"STORAGE_BACKEND\"] = \"azure\"\n", - "\n", - "# always the same\n", - "os.environ[\"STORAGE_BUCKET_NAME\"] = \"pyinfra-test-bucket\"\n", - "\n", - "# s3\n", - "os.environ[\"STORAGE_ENDPOINT\"] = \"https://s3.amazonaws.com\"\n", - "os.environ[\"STORAGE_KEY\"] = \"AKIA4QVP6D4LCDAGYGN2\"\n", - "os.environ[\"STORAGE_SECRET\"] = \"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\"\n", - "os.environ[\"STORAGE_REGION\"] = \"eu-west-1\"\n", - "\n", - "# aks\n", - "os.environ[\"STORAGE_AZURECONNECTIONSTRING\"] = \"DefaultEndpointsProtocol=https;AccountName=iqserdevelopment;AccountKey=4imAbV9PYXaztSOMpIyAClg88bAZCXuXMGJG0GA1eIBpdh2PlnFGoRBnKqLy2YZUSTmZ3wJfC7tzfHtuC6FEhQ==;EndpointSuffix=core.windows.net\"" - ] - }, - { - "cell_type": "code", - "execution_count": 23, - "metadata": {}, - "outputs": [ - { - "ename": "Exception", - "evalue": "Unknown storage backend 'aks'.", - "output_type": "error", - "traceback": [ - "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mException\u001b[0m Traceback (most recent call last)", - "Cell \u001b[0;32mIn [23], line 2\u001b[0m\n\u001b[1;32m 1\u001b[0m config \u001b[38;5;241m=\u001b[39m pyinfra\u001b[38;5;241m.\u001b[39mconfig\u001b[38;5;241m.\u001b[39mget_config()\n\u001b[0;32m----> 2\u001b[0m storage \u001b[38;5;241m=\u001b[39m \u001b[43mpyinfra\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mstorage\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mget_storage\u001b[49m\u001b[43m(\u001b[49m\u001b[43mconfig\u001b[49m\u001b[43m)\u001b[49m\n", - "File \u001b[0;32m~/dev/pyinfra/pyinfra/storage/storage.py:15\u001b[0m, in \u001b[0;36mget_storage\u001b[0;34m(config)\u001b[0m\n\u001b[1;32m 13\u001b[0m storage \u001b[39m=\u001b[39m get_azure_storage(config)\n\u001b[1;32m 14\u001b[0m \u001b[39melse\u001b[39;00m:\n\u001b[0;32m---> 15\u001b[0m \u001b[39mraise\u001b[39;00m \u001b[39mException\u001b[39;00m(\u001b[39mf\u001b[39m\u001b[39m\"\u001b[39m\u001b[39mUnknown storage backend \u001b[39m\u001b[39m'\u001b[39m\u001b[39m{\u001b[39;00mconfig\u001b[39m.\u001b[39mstorage_backend\u001b[39m}\u001b[39;00m\u001b[39m'\u001b[39m\u001b[39m.\u001b[39m\u001b[39m\"\u001b[39m)\n\u001b[1;32m 17\u001b[0m \u001b[39mreturn\u001b[39;00m storage\n", - "\u001b[0;31mException\u001b[0m: Unknown storage backend 'aks'." - ] - } - ], - "source": [ - "config = pyinfra.config.get_config()\n", - "storage = pyinfra.storage.get_storage(config)" - ] - }, - { - "cell_type": "code", - "execution_count": 21, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "False" - ] - }, - "execution_count": 21, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "storage.has_bucket(config.storage_bucket)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3.8.13 ('pyinfra-TboPpZ8z-py3.8')", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.8.13" - }, - "orig_nbformat": 4, - "vscode": { - "interpreter": { - "hash": "10d7419af5ea6dfec0078ebc9d6fa1a9383fe9894853f90dc7d29a81b3de2c78" - } - } - }, - "nbformat": 4, - "nbformat_minor": 2 -} diff --git a/tests/cipher_test.py b/tests/cipher_test.py new file mode 100644 index 0000000..2f0dbce --- /dev/null +++ b/tests/cipher_test.py @@ -0,0 +1,29 @@ +import pytest + +from pyinfra.utils.cipher import decrypt, encrypt + + +@pytest.fixture +def ciphertext(): + return "AAAADBRzag4/aAE2+rSekyI5phVZ1e0wwSaRkGQTLftPyVvq8vLYZzwxW48Wozc3/w==" + + +@pytest.fixture +def plaintext(): + return "connectzionString" + + +@pytest.fixture +def public_key(): + return "redaction" + + +class TestDecryption: + def test_decrypt_ciphertext(self, public_key, ciphertext, plaintext): + result = decrypt(public_key, ciphertext) + assert result == plaintext + + def test_encrypt_plaintext(self, public_key, plaintext): + ciphertext = encrypt(public_key, plaintext) + result = decrypt(public_key, ciphertext) + assert plaintext == result diff --git a/tests/conftest.py b/tests/conftest.py index 7b3fc33..053d592 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -6,8 +6,7 @@ import pytest import testcontainers.compose from pyinfra.config import get_config -from pyinfra.queue.queue_manager import QueueManager -from pyinfra.storage import get_storage +from pyinfra.storage import get_storage_from_config logger = logging.getLogger(__name__) logger.setLevel(logging.DEBUG) @@ -30,60 +29,35 @@ def docker_compose(sleep_seconds=30): @pytest.fixture(scope="session") -def storage_config(client_name): +def test_storage_config(storage_backend, bucket_name, monitoring_enabled): config = get_config() - config.storage_backend = client_name + config.storage_backend = storage_backend + config.storage_bucket = bucket_name config.storage_azureconnectionstring = "DefaultEndpointsProtocol=https;AccountName=iqserdevelopment;AccountKey=4imAbV9PYXaztSOMpIyAClg88bAZCXuXMGJG0GA1eIBpdh2PlnFGoRBnKqLy2YZUSTmZ3wJfC7tzfHtuC6FEhQ==;EndpointSuffix=core.windows.net" + config.monitoring_enabled = monitoring_enabled + config.prometheus_metric_prefix = "test" + config.prometheus_port = 8080 + config.prometheus_host = "0.0.0.0" return config @pytest.fixture(scope="session") -def processing_config(storage_config, monitoring_enabled): - storage_config.monitoring_enabled = monitoring_enabled - return storage_config - - -@pytest.fixture(scope="session") -def bucket_name(storage_config): - return storage_config.storage_bucket - - -@pytest.fixture(scope="session") -def storage(storage_config): - logger.debug("Setup for storage") - storage = get_storage(storage_config) - storage.make_bucket(storage_config.storage_bucket) - storage.clear_bucket(storage_config.storage_bucket) - yield storage - logger.debug("Teardown for storage") - try: - storage.clear_bucket(storage_config.storage_bucket) - except: - pass - - -@pytest.fixture(scope="session") -def queue_config(payload_processor_type): +def test_queue_config(): config = get_config() - # FIXME: It looks like rabbitmq_heartbeat has to be greater than rabbitmq_connection_sleep. If this is expected, the - # user should not be abele to insert non working values. - config.rabbitmq_heartbeat = config.rabbitmq_connection_sleep + 1 + config.rabbitmq_connection_sleep = 2 + config.rabbitmq_heartbeat = 4 return config -@pytest.fixture(scope="session") -def queue_manager(queue_config): - queue_manager = QueueManager(queue_config) - return queue_manager - - @pytest.fixture -def request_payload(): +def payload(x_tenant_id): + x_tenant_entry = {"X-TENANT-ID": x_tenant_id} if x_tenant_id else {} return { "dossierId": "test", "fileId": "test", "targetFileExtension": "json.gz", "responseFileExtension": "json.gz", + **x_tenant_entry, } @@ -93,3 +67,17 @@ def response_payload(): "dossierId": "test", "fileId": "test", } + + +@pytest.fixture(scope="session") +def storage(test_storage_config): + logger.debug("Setup for storage") + storage = get_storage_from_config(test_storage_config) + storage.make_bucket(test_storage_config.storage_bucket) + storage.clear_bucket(test_storage_config.storage_bucket) + yield storage + logger.debug("Teardown for storage") + try: + storage.clear_bucket(test_storage_config.storage_bucket) + except: + pass diff --git a/tests/lru_test.py b/tests/lru_test.py new file mode 100644 index 0000000..9ab574f --- /dev/null +++ b/tests/lru_test.py @@ -0,0 +1,48 @@ +from functools import lru_cache + +import pytest + + +def func(callback): + return callback() + + +@pytest.fixture() +def fn(maxsize): + return lru_cache(maxsize)(func) + + +@pytest.fixture(params=[1, 2, 5]) +def maxsize(request): + return request.param + + +class Callback: + def __init__(self, x): + self.initial_x = x + self.x = x + + def __call__(self, *args, **kwargs): + self.x += 1 + return self.x + + def __hash__(self): + return hash(self.initial_x) + + +def test_adding_to_cache_within_maxsize_does_not_overwrite(fn, maxsize): + c = Callback(0) + for i in range(maxsize): + assert fn(c) == 1 + assert fn(c) == 1 + + +def test_adding_to_cache_more_than_maxsize_does_overwrite(fn, maxsize): + + callbacks = [Callback(i) for i in range(maxsize)] + + for i in range(maxsize): + assert fn(callbacks[i]) == i + 1 + + assert fn(Callback(maxsize)) == maxsize + 1 + assert fn(callbacks[0]) == 2 diff --git a/tests/monitor_test.py b/tests/monitor_test.py index 391e681..fad8e6a 100644 --- a/tests/monitor_test.py +++ b/tests/monitor_test.py @@ -4,43 +4,41 @@ import time import pytest import requests -from pyinfra.config import get_config -from pyinfra.payload_processing.monitor import get_monitor +from pyinfra.payload_processing.monitor import PrometheusMonitor @pytest.fixture(scope="class") -def monitor_config(): - config = get_config() - config.prometheus_metric_prefix = "monitor_test" - config.prometheus_port = 8000 - return config - - -@pytest.fixture(scope="class") -def prometheus_monitor(monitor_config): - return get_monitor(monitor_config) - - -@pytest.fixture -def monitored_mock_function(prometheus_monitor): +def monitored_mock_function(metric_prefix, host, port): def process(data=None): time.sleep(2) return ["result1", "result2", "result3"] - return prometheus_monitor(process) + monitor = PrometheusMonitor(metric_prefix, host, port) + return monitor(process) +@pytest.fixture +def metric_endpoint(host, port): + return f"http://{host}:{port}/prometheus" + + +@pytest.mark.parametrize("metric_prefix, host, port", [("test", "0.0.0.0", 8000)], scope="class") class TestPrometheusMonitor: - def test_prometheus_endpoint_is_available(self, prometheus_monitor, monitor_config): - resp = requests.get(f"http://{monitor_config.prometheus_host}:{monitor_config.prometheus_port}/prometheus") + def test_prometheus_endpoint_is_available(self, metric_endpoint, monitored_mock_function): + resp = requests.get(metric_endpoint) assert resp.status_code == 200 - def test_processing_with_a_monitored_fn_increases_parameter_counter(self, monitored_mock_function, monitor_config): + def test_processing_with_a_monitored_fn_increases_parameter_counter( + self, + metric_endpoint, + metric_prefix, + monitored_mock_function, + ): monitored_mock_function(data=None) - resp = requests.get(f"http://{monitor_config.prometheus_host}:{monitor_config.prometheus_port}/prometheus") - pattern = re.compile(r".*monitor_test_processing_time_count (\d\.\d).*") + resp = requests.get(metric_endpoint) + pattern = re.compile(rf".*{metric_prefix}_processing_time_count (\d\.\d).*") assert pattern.search(resp.text).group(1) == "1.0" monitored_mock_function(data=None) - resp = requests.get(f"http://{monitor_config.prometheus_host}:{monitor_config.prometheus_port}/prometheus") + resp = requests.get(metric_endpoint) assert pattern.search(resp.text).group(1) == "2.0" diff --git a/tests/payload_parsing_test.py b/tests/payload_parsing_test.py new file mode 100644 index 0000000..303f301 --- /dev/null +++ b/tests/payload_parsing_test.py @@ -0,0 +1,54 @@ +import pytest + +from pyinfra.payload_processing.payload import ( + QueueMessagePayload, + QueueMessagePayloadParser, +) +from pyinfra.utils.file_extension_parsing import make_file_extension_parser + + +@pytest.fixture +def expected_parsed_payload(x_tenant_id): + return QueueMessagePayload( + dossier_id="test", + file_id="test", + x_tenant_id=x_tenant_id, + target_file_extension="json.gz", + response_file_extension="json.gz", + target_file_type="json", + target_compression_type="gz", + response_file_type="json", + response_compression_type="gz", + target_file_name="test/test.json.gz", + response_file_name="test/test.json.gz", + processing_kwargs={}, + ) + + +@pytest.fixture +def file_extension_parser(allowed_file_types, allowed_compression_types): + return make_file_extension_parser(allowed_file_types, allowed_compression_types) + + +@pytest.fixture +def payload_parser(file_extension_parser): + return QueueMessagePayloadParser(file_extension_parser) + + +@pytest.mark.parametrize("allowed_file_types,allowed_compression_types", [(["json", "pdf"], ["gz"])]) +class TestPayload: + @pytest.mark.parametrize("x_tenant_id", [None, "klaus"]) + def test_payload_is_parsed_correctly(self, payload_parser, payload, expected_parsed_payload): + payload = payload_parser(payload) + assert payload == expected_parsed_payload + + @pytest.mark.parametrize( + "extension,expected", + [ + ("json.gz", ("json", "gz")), + ("json", ("json", None)), + ("prefix.json.gz", ("json", "gz")), + ], + ) + def test_parse_file_extension(self, file_extension_parser, extension, expected): + assert file_extension_parser(extension) == expected diff --git a/tests/processing_test.py b/tests/payload_processor_test.py similarity index 63% rename from tests/processing_test.py rename to tests/payload_processor_test.py index 18698d7..0c48ac4 100644 --- a/tests/processing_test.py +++ b/tests/payload_processor_test.py @@ -8,14 +8,6 @@ import requests from pyinfra.payload_processing.processor import make_payload_processor -@pytest.fixture(scope="session") -def file_processor_mock(): - def inner(json_file: dict): - return [json_file] - - return inner - - @pytest.fixture def target_file(): contents = {"numberOfPages": 10, "content1": "value1", "content2": "value2"} @@ -23,54 +15,60 @@ def target_file(): @pytest.fixture -def file_names(request_payload): +def file_names(payload): dossier_id, file_id, target_suffix, response_suffix = itemgetter( "dossierId", "fileId", "targetFileExtension", "responseFileExtension", - )(request_payload) + )(payload) return f"{dossier_id}/{file_id}.{target_suffix}", f"{dossier_id}/{file_id}.{response_suffix}" @pytest.fixture(scope="session") -def payload_processor(file_processor_mock, processing_config): - yield make_payload_processor(file_processor_mock, processing_config) +def payload_processor(test_storage_config): + def file_processor_mock(json_file: dict): + return [json_file] + + yield make_payload_processor(file_processor_mock, test_storage_config) -@pytest.mark.parametrize("client_name", ["s3"], scope="session") +@pytest.mark.parametrize("storage_backend", ["s3"], scope="session") +@pytest.mark.parametrize("bucket_name", ["testbucket"], scope="session") @pytest.mark.parametrize("monitoring_enabled", [True, False], scope="session") +@pytest.mark.parametrize("x_tenant_id", [None]) class TestPayloadProcessor: def test_payload_processor_yields_correct_response_and_uploads_result( self, payload_processor, storage, bucket_name, - request_payload, + payload, response_payload, target_file, file_names, ): storage.clear_bucket(bucket_name) storage.put_object(bucket_name, file_names[0], target_file) - response = payload_processor(request_payload) + response = payload_processor(payload) assert response == response_payload data_received = storage.get_object(bucket_name, file_names[1]) assert json.loads((gzip.decompress(data_received)).decode("utf-8")) == { - **request_payload, + **payload, "data": [json.loads(gzip.decompress(target_file).decode("utf-8"))], } - def test_catching_of_processing_failure(self, payload_processor, storage, bucket_name, request_payload): + def test_catching_of_processing_failure(self, payload_processor, storage, bucket_name, payload): storage.clear_bucket(bucket_name) with pytest.raises(Exception): - payload_processor(request_payload) + payload_processor(payload) - def test_prometheus_endpoint_is_available(self, processing_config): - resp = requests.get( - f"http://{processing_config.prometheus_host}:{processing_config.prometheus_port}/prometheus" - ) - assert resp.status_code == 200 \ No newline at end of file + def test_prometheus_endpoint_is_available(self, test_storage_config, monitoring_enabled, storage_backend, x_tenant_id): + if monitoring_enabled: + resp = requests.get( + f"http://{test_storage_config.prometheus_host}:{test_storage_config.prometheus_port}/prometheus" + ) + assert resp.status_code == 200 diff --git a/tests/payload_test.py b/tests/payload_test.py deleted file mode 100644 index 5f0bb36..0000000 --- a/tests/payload_test.py +++ /dev/null @@ -1,44 +0,0 @@ -import pytest - -from pyinfra.config import get_config -from pyinfra.payload_processing.payload import ( - QueueMessagePayload, - get_queue_message_payload_parser, -) -from pyinfra.utils.file_extension_parsing import make_file_extension_parser - - -@pytest.fixture(scope="session") -def payload_config(): - return get_config() - - -class TestPayload: - def test_payload_is_parsed_correctly(self, request_payload, payload_config): - parse_payload = get_queue_message_payload_parser(payload_config) - payload = parse_payload(request_payload) - assert payload == QueueMessagePayload( - dossier_id="test", - file_id="test", - target_file_extension="json.gz", - response_file_extension="json.gz", - target_file_type="json", - target_compression_type="gz", - response_file_type="json", - response_compression_type="gz", - target_file_name="test/test.json.gz", - response_file_name="test/test.json.gz", - ) - - @pytest.mark.parametrize( - "extension,expected", - [ - ("json.gz", ("json", "gz")), - ("json", ("json", None)), - ("prefix.json.gz", ("json", "gz")), - ], - ) - @pytest.mark.parametrize("allowed_file_types,allowed_compression_types", [(["json", "pdf"], ["gz"])]) - def test_parse_file_extension(self, extension, expected, allowed_file_types, allowed_compression_types): - parse = make_file_extension_parser(allowed_file_types, allowed_compression_types) - assert parse(extension) == expected diff --git a/tests/queue_test.py b/tests/queue_manager_test.py similarity index 78% rename from tests/queue_test.py rename to tests/queue_manager_test.py index 8816e20..d6c9118 100644 --- a/tests/queue_test.py +++ b/tests/queue_manager_test.py @@ -8,15 +8,16 @@ import pika.exceptions import pytest from pyinfra.queue.development_queue_manager import DevelopmentQueueManager +from pyinfra.queue.queue_manager import QueueManager logger = logging.getLogger(__name__) logger.setLevel(logging.DEBUG) @pytest.fixture(scope="session") -def development_queue_manager(queue_config): - queue_config.rabbitmq_heartbeat = 7200 - development_queue_manager = DevelopmentQueueManager(queue_config) +def development_queue_manager(test_queue_config): + test_queue_config.rabbitmq_heartbeat = 7200 + development_queue_manager = DevelopmentQueueManager(test_queue_config) yield development_queue_manager logger.info("Tearing down development queue manager...") try: @@ -26,10 +27,10 @@ def development_queue_manager(queue_config): @pytest.fixture(scope="session") -def payload_processing_time(queue_config, offset=5): +def payload_processing_time(test_queue_config, offset=5): # FIXME: this implicitly tests the heartbeat when running the end-to-end test. There should be another way to test # this explicitly. - return queue_config.rabbitmq_heartbeat + offset + return test_queue_config.rabbitmq_heartbeat + offset @pytest.fixture(scope="session") @@ -48,10 +49,11 @@ def payload_processor(response_payload, payload_processing_time, payload_process @pytest.fixture(scope="session", autouse=True) -def start_queue_consumer(queue_manager, payload_processor, sleep_seconds=5): +def start_queue_consumer(test_queue_config, payload_processor, sleep_seconds=5): def consume_queue(): queue_manager.start_consuming(payload_processor) + queue_manager = QueueManager(test_queue_config) p = Process(target=consume_queue) p.start() logger.info(f"Setting up consumer, waiting for {sleep_seconds}...") @@ -65,39 +67,40 @@ def start_queue_consumer(queue_manager, payload_processor, sleep_seconds=5): def message_properties(message_headers): if not message_headers: return pika.BasicProperties(headers=None) - elif message_headers == "x-tenant-id": - return pika.BasicProperties(headers={"x-tenant-id": "redaction"}) + elif message_headers == "X-TENANT-ID": + return pika.BasicProperties(headers={"X-TENANT-ID": "redaction"}) else: raise Exception(f"Invalid {message_headers=}.") +@pytest.mark.parametrize("x_tenant_id", [None]) class TestQueueManager: # FIXME: All tests here are wonky. This is due to the implementation of running the process-blocking queue_manager # in a subprocess. It is then very hard to interact directly with the subprocess. If you have a better idea, please # refactor; the tests here are insufficient to ensure the functionality of the queue manager! @pytest.mark.parametrize("payload_processor_type", ["mock"], scope="session") def test_message_processing_does_not_block_heartbeat( - self, development_queue_manager, request_payload, response_payload, payload_processing_time + self, development_queue_manager, payload, response_payload, payload_processing_time ): development_queue_manager.clear_queues() - development_queue_manager.publish_request(request_payload) + development_queue_manager.publish_request(payload) time.sleep(payload_processing_time + 10) _, _, body = development_queue_manager.get_response() result = json.loads(body) assert result == response_payload - @pytest.mark.parametrize("message_headers", [None, "x-tenant-id"]) + @pytest.mark.parametrize("message_headers", [None, "X-TENANT-ID"]) @pytest.mark.parametrize("payload_processor_type", ["mock"], scope="session") def test_queue_manager_forwards_message_headers( self, development_queue_manager, - request_payload, + payload, response_payload, payload_processing_time, message_properties, ): development_queue_manager.clear_queues() - development_queue_manager.publish_request(request_payload, message_properties) + development_queue_manager.publish_request(payload, message_properties) time.sleep(payload_processing_time + 10) _, properties, _ = development_queue_manager.get_response() assert properties.headers == message_properties.headers @@ -109,12 +112,12 @@ class TestQueueManager: def test_failed_message_processing_is_handled( self, development_queue_manager, - request_payload, + payload, response_payload, payload_processing_time, ): development_queue_manager.clear_queues() - development_queue_manager.publish_request(request_payload) + development_queue_manager.publish_request(payload) time.sleep(payload_processing_time + 10) _, _, body = development_queue_manager.get_response() assert not body diff --git a/tests/storage_test.py b/tests/storage_test.py index 9d1635d..80ab4b0 100644 --- a/tests/storage_test.py +++ b/tests/storage_test.py @@ -6,7 +6,9 @@ logger = logging.getLogger(__name__) logger.setLevel(logging.DEBUG) -@pytest.mark.parametrize("client_name", ["azure", "s3"], scope="session") +@pytest.mark.parametrize("storage_backend", ["azure", "s3"], scope="session") +@pytest.mark.parametrize("bucket_name", ["testbucket"], scope="session") +@pytest.mark.parametrize("monitoring_enabled", [False], scope="session") class TestStorage: def test_clearing_bucket_yields_empty_bucket(self, storage, bucket_name): storage.clear_bucket(bucket_name)