From d48e8108fdc0d463c89aaa0d672061ab7dca83a0 Mon Sep 17 00:00:00 2001 From: Julius Unverfehrt Date: Wed, 22 Mar 2023 13:34:43 +0100 Subject: [PATCH] add multi-tenant storage connection 1st iteration - forward x-tenant-id from queue message header to payload processor - add functions to receive storage infos from an endpoint or the config. This enables hashing and caching of connections created from these infos - add function to initialize storage connections from storage infos - streamline and refactor tests to make them more readable and robust and to make it easier to add new tests - update payload processor with first iteration of multi tenancy storage connection support with connection caching and backwards compability --- pyinfra/config.py | 12 +- pyinfra/exception.py | 3 + pyinfra/payload_processing/monitor.py | 11 +- pyinfra/payload_processing/payload.py | 4 + pyinfra/payload_processing/processor.py | 70 +++++-- pyinfra/queue/queue_manager.py | 12 +- pyinfra/storage/__init__.py | 4 +- pyinfra/storage/storage.py | 37 +++- pyinfra/storage/storage_info.py | 89 ++++++++ pyinfra/storage/storages/azure.py | 2 +- pyinfra/storage/storages/s3.py | 2 +- pyinfra/utils/dict.py | 5 + scripts/send_request.py | 8 +- test.ipynb | 194 ------------------ tests/conftest.py | 68 +++--- tests/monitor_test.py | 44 ++-- tests/payload_parsing_test.py | 53 +++++ ...sing_test.py => payload_processor_test.py} | 44 ++-- tests/payload_test.py | 44 ---- .../{queue_test.py => queue_manager_test.py} | 33 +-- tests/storage_test.py | 4 +- 21 files changed, 353 insertions(+), 390 deletions(-) create mode 100644 pyinfra/storage/storage_info.py create mode 100644 pyinfra/utils/dict.py delete mode 100644 test.ipynb create mode 100644 tests/payload_parsing_test.py rename tests/{processing_test.py => payload_processor_test.py} (63%) delete mode 100644 tests/payload_test.py rename tests/{queue_test.py => queue_manager_test.py} (78%) diff --git a/pyinfra/config.py b/pyinfra/config.py index 
0f1ec1c..0afb57c 100644 --- a/pyinfra/config.py +++ b/pyinfra/config.py @@ -28,11 +28,9 @@ class Config: "PROMETHEUS_METRIC_PREFIX", "redactmanager_research_service_parameter" ) - # Prometheus webserver address - self.prometheus_host = read_from_environment("PROMETHEUS_HOST", "127.0.0.1") - - # Prometheus webserver port - self.prometheus_port = int(read_from_environment("PROMETHEUS_PORT", 8080)) + # Prometheus webserver address and port + self.prometheus_host = "0.0.0.0" + self.prometheus_port = 8080 # RabbitMQ host address self.rabbitmq_host = read_from_environment("RABBITMQ_HOST", "localhost") @@ -94,6 +92,10 @@ class Config: self.allowed_file_types = ["json", "pdf"] self.allowed_compression_types = ["gz"] + # config for x-tenant-endpoint to receive storage connection information per tenant + self.persistence_service_public_key = "redaction" + self.persistence_service_tenant_endpoint = "http://persistence-service-v1:8080/internal-api/tenants" + # Value to see if we should write a consumer token to a file self.write_consumer_token = read_from_environment("WRITE_CONSUMER_TOKEN", "False") diff --git a/pyinfra/exception.py b/pyinfra/exception.py index 74a9dcb..b8d35de 100644 --- a/pyinfra/exception.py +++ b/pyinfra/exception.py @@ -1,2 +1,5 @@ class ProcessingFailure(RuntimeError): pass + +class UnknownStorageBackend(Exception): + pass \ No newline at end of file diff --git a/pyinfra/payload_processing/monitor.py b/pyinfra/payload_processing/monitor.py index b7f2530..a507bd9 100644 --- a/pyinfra/payload_processing/monitor.py +++ b/pyinfra/payload_processing/monitor.py @@ -12,7 +12,7 @@ logger = logging.getLogger() class PrometheusMonitor: - def __init__(self, prefix: str, port=8080, host="127.0.0.1"): + def __init__(self, prefix: str, host: str, port: int): """Register the monitoring metrics and start a webserver where they can be scraped at the endpoint http://{host}:{port}/prometheus @@ -23,12 +23,9 @@ class PrometheusMonitor: self.registry = CollectorRegistry() 
self.entity_processing_time_sum = Summary( - f"{prefix}_processing_time", - "Summed up average processing time per entity observed", + f"{prefix}_processing_time", "Summed up average processing time per entity observed", registry=self.registry ) - self.registry.register(self.entity_processing_time_sum) - start_http_server(port, host, self.registry) def __call__(self, process_fn: Callable) -> Callable: @@ -58,8 +55,8 @@ class PrometheusMonitor: return inner -def get_monitor(config: Config) -> Callable: +def get_monitor_from_config(config: Config) -> Callable: if config.monitoring_enabled: - return PrometheusMonitor(*attrgetter("prometheus_metric_prefix", "prometheus_port", "prometheus_host")(config)) + return PrometheusMonitor(*attrgetter("prometheus_metric_prefix", "prometheus_host", "prometheus_port")(config)) else: return identity diff --git a/pyinfra/payload_processing/payload.py b/pyinfra/payload_processing/payload.py index 48d4bd2..4381a8b 100644 --- a/pyinfra/payload_processing/payload.py +++ b/pyinfra/payload_processing/payload.py @@ -11,6 +11,8 @@ from pyinfra.utils.file_extension_parsing import make_file_extension_parser class QueueMessagePayload: dossier_id: str file_id: str + x_tenant_id: Union[str, None] + target_file_extension: str response_file_extension: str @@ -35,6 +37,7 @@ class QueueMessagePayloadParser: dossier_id, file_id, target_file_extension, response_file_extension = itemgetter( "dossierId", "fileId", "targetFileExtension", "responseFileExtension" )(payload) + x_tenant_id = payload.get("X-TENANT-ID") target_file_type, target_compression_type, response_file_type, response_compression_type = chain.from_iterable( map(self.parse_file_extensions, [target_file_extension, response_file_extension]) @@ -46,6 +49,7 @@ class QueueMessagePayloadParser: return QueueMessagePayload( dossier_id=dossier_id, file_id=file_id, + x_tenant_id=x_tenant_id, target_file_extension=target_file_extension, response_file_extension=response_file_extension, 
target_file_type=target_file_type, diff --git a/pyinfra/payload_processing/processor.py b/pyinfra/payload_processing/processor.py index d908dcd..0202a0d 100644 --- a/pyinfra/payload_processing/processor.py +++ b/pyinfra/payload_processing/processor.py @@ -6,16 +6,20 @@ from typing import Callable, Union, List from funcy import compose from pyinfra.config import get_config, Config -from pyinfra.payload_processing.monitor import get_monitor +from pyinfra.payload_processing.monitor import get_monitor_from_config from pyinfra.payload_processing.payload import ( QueueMessagePayloadParser, get_queue_message_payload_parser, QueueMessagePayloadFormatter, get_queue_message_payload_formatter, ) -from pyinfra.storage import get_storage -from pyinfra.storage.storage import make_downloader, make_uploader -from pyinfra.storage.storages.interface import Storage +from pyinfra.storage.storage import make_downloader, make_uploader, get_storage_from_storage_info +from pyinfra.storage.storage_info import ( + AzureStorageInfo, + S3StorageInfo, + get_storage_info_from_config, + get_storage_info_from_endpoint, +) logger = logging.getLogger() logger.setLevel(get_config().logging_level_root) @@ -24,8 +28,8 @@ logger.setLevel(get_config().logging_level_root) class PayloadProcessor: def __init__( self, - storage: Storage, - bucket: str, + default_storage_info: Union[AzureStorageInfo, S3StorageInfo], + get_storage_info_from_tenant_id, payload_parser: QueueMessagePayloadParser, payload_formatter: QueueMessagePayloadFormatter, data_processor: Callable, @@ -33,8 +37,9 @@ class PayloadProcessor: """Wraps an analysis function specified by a service (e.g. NER service) in pre- and post-processing steps. Args: - storage: The storage to use for downloading and uploading files - bucket: The bucket to use for downloading and uploading files + default_storage_info: The default storage info used to create the storage connection. This is only used if + x_tenant_id is not provided in the queue payload. 
+ get_storage_info_from_tenant_id: Callable to acquire storage info from a given tenant id. payload_parser: Parser that translates the queue message payload to the required QueueMessagePayload object payload_formatter: Formatter for the storage upload result and the queue message response body data_processor: The analysis function to be called with the downloaded file @@ -46,8 +51,10 @@ class PayloadProcessor: self.format_to_queue_message_response_body = payload_formatter.format_to_queue_message_response_body self.process_data = data_processor - self.make_downloader = partial(make_downloader, storage, bucket) - self.make_uploader = partial(make_uploader, storage, bucket) + self.get_storage_info_from_tenant_id = get_storage_info_from_tenant_id + self.default_storage_info = default_storage_info + # TODO: use lru-dict + self.storages = {} def __call__(self, queue_message_payload: dict) -> dict: """Processes a queue message payload. @@ -71,8 +78,16 @@ class PayloadProcessor: payload = self.parse_payload(queue_message_payload) logger.info(f"Processing {asdict(payload)} ...") - download_file_to_process = self.make_downloader(payload.target_file_type, payload.target_compression_type) - upload_processing_result = self.make_uploader(payload.response_file_type, payload.response_compression_type) + storage_info = self._get_storage_info(payload.x_tenant_id) + bucket = storage_info.bucket_name + storage = self._get_storage(storage_info) + + download_file_to_process = make_downloader( + storage, bucket, payload.target_file_type, payload.target_compression_type + ) + upload_processing_result = make_uploader( + storage, bucket, payload.response_file_type, payload.response_compression_type + ) format_result_for_storage = partial(self.format_result_for_storage, payload) processing_pipeline = compose(format_result_for_storage, self.process_data, download_file_to_process) @@ -83,17 +98,40 @@ class PayloadProcessor: return self.format_to_queue_message_response_body(payload) + def 
_get_storage_info(self, x_tenant_id=None): + if x_tenant_id: + return self.get_storage_info_from_tenant_id(x_tenant_id) + return self.default_storage_info + + def _get_storage(self, storage_info): + if storage_info in self.storages: + return self.storages[storage_info] + else: + storage = get_storage_from_storage_info(storage_info) + self.storages[storage_info] = storage + return storage + def make_payload_processor(data_processor: Callable, config: Union[None, Config] = None) -> PayloadProcessor: """Produces payload processor for queue manager.""" config = config or get_config() - bucket: str = config.storage_bucket - storage: Storage = get_storage(config) - monitor = get_monitor(config) + default_storage_info: Union[AzureStorageInfo, S3StorageInfo] = get_storage_info_from_config(config) + get_storage_info_from_tenant_id = partial( + get_storage_info_from_endpoint, + config.persistence_service_public_key, + config.persistence_service_tenant_endpoint, + ) + monitor = get_monitor_from_config(config) payload_parser: QueueMessagePayloadParser = get_queue_message_payload_parser(config) payload_formatter: QueueMessagePayloadFormatter = get_queue_message_payload_formatter() data_processor = monitor(data_processor) - return PayloadProcessor(storage, bucket, payload_parser, payload_formatter, data_processor) + return PayloadProcessor( + default_storage_info, + get_storage_info_from_tenant_id, + payload_parser, + payload_formatter, + data_processor, + ) diff --git a/pyinfra/queue/queue_manager.py b/pyinfra/queue/queue_manager.py index ec241a0..6a0a2f2 100644 --- a/pyinfra/queue/queue_manager.py +++ b/pyinfra/queue/queue_manager.py @@ -12,6 +12,7 @@ from pika.adapters.blocking_connection import BlockingChannel from pyinfra.config import Config from pyinfra.exception import ProcessingFailure from pyinfra.payload_processing.processor import PayloadProcessor +from pyinfra.utils.dict import save_project CONFIG = Config() @@ -164,8 +165,8 @@ class QueueManager: except Exception 
as err: raise ProcessingFailure("QueueMessagePayload processing failed") from err - def acknowledge_message_and_publish_response(frame, properties, response_body): - response_properties = pika.BasicProperties(headers=properties.headers) if properties.headers else None + def acknowledge_message_and_publish_response(frame, headers, response_body): + response_properties = pika.BasicProperties(headers=headers) if headers else None self._channel.basic_publish("", self._output_queue, json.dumps(response_body).encode(), response_properties) self.logger.info( "Result published, acknowledging incoming message with delivery_tag %s", @@ -190,12 +191,15 @@ class QueueManager: try: self.logger.debug("Processing (%s, %s, %s)", frame, properties, body) - processing_result = process_message_body_and_await_result(json.loads(body)) + filtered_message_headers = save_project(properties.headers, ["X-TENANT-ID"]) # TODO: parametrize key? + message_body = {**json.loads(body), **filtered_message_headers} + + processing_result = process_message_body_and_await_result(message_body) self.logger.info( "Processed message with delivery_tag %s, publishing result to result-queue", frame.delivery_tag, ) - acknowledge_message_and_publish_response(frame, properties, processing_result) + acknowledge_message_and_publish_response(frame, filtered_message_headers, processing_result) except ProcessingFailure: self.logger.info( diff --git a/pyinfra/storage/__init__.py b/pyinfra/storage/__init__.py index f5d004f..dccdcda 100644 --- a/pyinfra/storage/__init__.py +++ b/pyinfra/storage/__init__.py @@ -1,3 +1,3 @@ -from pyinfra.storage.storage import get_storage +from pyinfra.storage.storage import get_storage_from_config -__all__ = ["get_storage"] +__all__ = ["get_storage_from_config"] diff --git a/pyinfra/storage/storage.py b/pyinfra/storage/storage.py index 40128b8..5e14bd5 100644 --- a/pyinfra/storage/storage.py +++ b/pyinfra/storage/storage.py @@ -1,28 +1,45 @@ from functools import lru_cache, partial -from 
typing import Callable +from typing import Callable, Union +from azure.storage.blob import BlobServiceClient from funcy import compose +from minio import Minio from pyinfra.config import Config -from pyinfra.storage.storages.azure import get_azure_storage -from pyinfra.storage.storages.s3 import get_s3_storage +from pyinfra.exception import UnknownStorageBackend +from pyinfra.storage.storage_info import AzureStorageInfo, S3StorageInfo, get_storage_info_from_config +from pyinfra.storage.storages.azure import AzureStorage from pyinfra.storage.storages.interface import Storage +from pyinfra.storage.storages.s3 import S3Storage from pyinfra.utils.compressing import get_decompressor, get_compressor from pyinfra.utils.encoding import get_decoder, get_encoder -def get_storage(config: Config) -> Storage: +def get_storage_from_config(config: Config) -> Storage: - if config.storage_backend == "s3": - storage = get_s3_storage(config) - elif config.storage_backend == "azure": - storage = get_azure_storage(config) - else: - raise Exception(f"Unknown storage backend '{config.storage_backend}'.") + storage_info = get_storage_info_from_config(config) + storage = get_storage_from_storage_info(storage_info) return storage +def get_storage_from_storage_info(storage_info: Union[AzureStorageInfo, S3StorageInfo]) -> Storage: + if isinstance(storage_info, AzureStorageInfo): + return AzureStorage(BlobServiceClient.from_connection_string(conn_str=storage_info.connection_string)) + elif isinstance(storage_info, S3StorageInfo): + return S3Storage( + Minio( + secure=storage_info.secure, + endpoint=storage_info.endpoint, + access_key=storage_info.access_key, + secret_key=storage_info.secret_key, + region=storage_info.region, + ) + ) + else: + raise UnknownStorageBackend() + + def verify_existence(storage: Storage, bucket: str, file_name: str) -> str: if not storage.exists(bucket, file_name): raise FileNotFoundError(f"{file_name=} name not found on storage in {bucket=}.") diff --git 
a/pyinfra/storage/storage_info.py b/pyinfra/storage/storage_info.py new file mode 100644 index 0000000..80d0cdd --- /dev/null +++ b/pyinfra/storage/storage_info.py @@ -0,0 +1,89 @@ +from dataclasses import dataclass +from typing import Union + +import requests + +from pyinfra.config import Config +from pyinfra.exception import UnknownStorageBackend +from pyinfra.utils.cipher import decrypt +from pyinfra.utils.url_parsing import validate_and_parse_s3_endpoint + + +@dataclass(frozen=True) +class AzureStorageInfo: + connection_string: str + bucket_name: str + + def __hash__(self): + return hash(self.connection_string) + + +@dataclass(frozen=True) +class S3StorageInfo: + secure: bool + endpoint: str + access_key: str + secret_key: str + region: str + bucket_name: str + + def __hash__(self): + return hash((self.secure, self.endpoint, self.access_key, self.secret_key, self.region)) + + +def get_storage_info_from_endpoint( + public_key: str, endpoint: str, x_tenant_id: str +) -> Union[AzureStorageInfo, S3StorageInfo]: + # FIXME: parameterize port, host and public_key + # NOTE(review): use the public_key argument; hardcoded "redaction" override removed + resp = requests.get(f"{endpoint}/{x_tenant_id}").json() + + maybe_azure = resp.get("azureStorageConnection") + maybe_s3 = resp.get("s3StorageConnection") # was duplicate azure key; S3 branch was unreachable + assert not (maybe_azure and maybe_s3) + + if maybe_azure: + connection_string = decrypt(public_key, maybe_azure["connectionString"]) + storage_info = AzureStorageInfo( + connection_string=connection_string, + bucket_name=maybe_azure["containerName"], + ) + elif maybe_s3: + secure, endpoint = validate_and_parse_s3_endpoint(maybe_s3["endpoint"]) + secret = decrypt(public_key, maybe_s3["secret"]) + + storage_info = S3StorageInfo( + secure=secure, + endpoint=endpoint, + access_key=maybe_s3["key"], + secret_key=secret, + region=maybe_s3["region"], + bucket_name=maybe_s3["bucketName"], # was the whole dict; TODO confirm response field name + ) + else: + raise UnknownStorageBackend() + + return storage_info + + +def get_storage_info_from_config(config: Config) -> Union[AzureStorageInfo,
S3StorageInfo]: + if config.storage_backend == "s3": + storage_info = S3StorageInfo( + secure=config.storage_secure_connection, + endpoint=config.storage_endpoint, + access_key=config.storage_key, + secret_key=config.storage_secret, + region=config.storage_region, + bucket_name=config.storage_bucket, + ) + + elif config.storage_backend == "azure": + storage_info = AzureStorageInfo( + connection_string=config.storage_azureconnectionstring, + bucket_name=config.storage_bucket, + ) + + else: + raise UnknownStorageBackend(f"Unknown storage backend '{config.storage_backend}'.") + + return storage_info diff --git a/pyinfra/storage/storages/azure.py b/pyinfra/storage/storages/azure.py index aaa9ba2..f6091a4 100644 --- a/pyinfra/storage/storages/azure.py +++ b/pyinfra/storage/storages/azure.py @@ -77,5 +77,5 @@ class AzureStorage(Storage): return zip(repeat(bucket_name), map(attrgetter("name"), blobs)) -def get_azure_storage(config: Config): +def get_azure_storage_from_config(config: Config): return AzureStorage(BlobServiceClient.from_connection_string(conn_str=config.storage_azureconnectionstring)) diff --git a/pyinfra/storage/storages/s3.py b/pyinfra/storage/storages/s3.py index 5763353..3eafeff 100644 --- a/pyinfra/storage/storages/s3.py +++ b/pyinfra/storage/storages/s3.py @@ -67,7 +67,7 @@ class S3Storage(Storage): return zip(repeat(bucket_name), map(attrgetter("object_name"), objs)) -def get_s3_storage(config: Config): +def get_s3_storage_from_config(config: Config): return S3Storage( Minio( secure=config.storage_secure_connection, diff --git a/pyinfra/utils/dict.py b/pyinfra/utils/dict.py new file mode 100644 index 0000000..a220d5a --- /dev/null +++ b/pyinfra/utils/dict.py @@ -0,0 +1,5 @@ +from funcy import project + + +def save_project(mapping, keys) -> dict: + return project(mapping, keys) if mapping else {} diff --git a/scripts/send_request.py b/scripts/send_request.py index 2b9c4b7..a30f725 100644 --- a/scripts/send_request.py +++ b/scripts/send_request.py @@ 
-7,7 +7,7 @@ import pika from pyinfra.config import get_config from pyinfra.queue.development_queue_manager import DevelopmentQueueManager -from pyinfra.storage.storages.s3 import get_s3_storage +from pyinfra.storage.storages.s3 import get_s3_storage_from_config CONFIG = get_config() logging.basicConfig() @@ -26,7 +26,7 @@ def upload_json_and_make_message_body(): object_name = f"{dossier_id}/{file_id}.{suffix}" data = gzip.compress(json.dumps(content).encode("utf-8")) - storage = get_s3_storage(CONFIG) + storage = get_s3_storage_from_config(CONFIG) if not storage.has_bucket(bucket): storage.make_bucket(bucket) storage.put_object(bucket, object_name, data) @@ -46,10 +46,10 @@ def main(): message = upload_json_and_make_message_body() - development_queue_manager.publish_request(message, pika.BasicProperties(headers={"x-tenant-id": "redaction"})) + development_queue_manager.publish_request(message, pika.BasicProperties(headers={"X-TENANT-ID": "redaction"})) logger.info(f"Put {message} on {CONFIG.request_queue}") - storage = get_s3_storage(CONFIG) + storage = get_s3_storage_from_config(CONFIG) for method_frame, properties, body in development_queue_manager._channel.consume( queue=CONFIG.response_queue, inactivity_timeout=15 ): diff --git a/test.ipynb b/test.ipynb deleted file mode 100644 index 9f76ee4..0000000 --- a/test.ipynb +++ /dev/null @@ -1,194 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": 10, - "metadata": {}, - "outputs": [ - { - "ename": "ModuleNotFoundError", - "evalue": "No module named 'pprint.pprint'; 'pprint' is not a package", - "output_type": "error", - "traceback": [ - "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mModuleNotFoundError\u001b[0m Traceback (most recent call last)", - "Cell \u001b[0;32mIn [10], line 4\u001b[0m\n\u001b[1;32m 2\u001b[0m \u001b[38;5;28;01mimport\u001b[39;00m \u001b[38;5;21;01myaml\u001b[39;00m\n\u001b[1;32m 3\u001b[0m 
\u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01myaml\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mloader\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m FullLoader\n\u001b[0;32m----> 4\u001b[0m \u001b[38;5;28;01mimport\u001b[39;00m \u001b[38;5;21;01mpprint\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mpprint\u001b[39;00m \u001b[38;5;28;01mas\u001b[39;00m \u001b[38;5;21;01mpp\u001b[39;00m\n", - "\u001b[0;31mModuleNotFoundError\u001b[0m: No module named 'pprint.pprint'; 'pprint' is not a package" - ] - } - ], - "source": [ - "import pyinfra\n", - "import yaml\n", - "from yaml.loader import FullLoader\n", - "import pprint" - ] - }, - { - "cell_type": "code", - "execution_count": 12, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "{'logging': 0,\n", - " 'mock_analysis_endpoint': 'http://127.0.0.1:5000',\n", - " 'service': {'operations': {'classify': {'input': {'extension': 'cls_in.gz',\n", - " 'multi': True,\n", - " 'subdir': ''},\n", - " 'output': {'extension': 'cls_out.gz',\n", - " 'subdir': ''}},\n", - " 'default': {'input': {'extension': 'IN.gz',\n", - " 'multi': False,\n", - " 'subdir': ''},\n", - " 'output': {'extension': 'OUT.gz',\n", - " 'subdir': ''}},\n", - " 'extract': {'input': {'extension': 'extr_in.gz',\n", - " 'multi': False,\n", - " 'subdir': ''},\n", - " 'output': {'extension': 'gz',\n", - " 'subdir': 'extractions'}},\n", - " 'rotate': {'input': {'extension': 'rot_in.gz',\n", - " 'multi': False,\n", - " 'subdir': ''},\n", - " 'output': {'extension': 'rot_out.gz',\n", - " 'subdir': ''}},\n", - " 'stream_pages': {'input': {'extension': 'pgs_in.gz',\n", - " 'multi': False,\n", - " 'subdir': ''},\n", - " 'output': {'extension': 'pgs_out.gz',\n", - " 'subdir': 'pages'}},\n", - " 'upper': {'input': {'extension': 'up_in.gz',\n", - " 'multi': False,\n", - " 'subdir': ''},\n", - " 'output': {'extension': 'up_out.gz',\n", - " 'subdir': ''}}},\n", - " 
'response_formatter': 'identity'},\n", - " 'storage': {'aws': {'access_key': 'AKIA4QVP6D4LCDAGYGN2',\n", - " 'endpoint': 'https://s3.amazonaws.com',\n", - " 'region': '$STORAGE_REGION|\"eu-west-1\"',\n", - " 'secret_key': '8N6H1TUHTsbvW2qMAm7zZlJ63hMqjcXAsdN7TYED'},\n", - " 'azure': {'connection_string': 'DefaultEndpointsProtocol=https;AccountName=iqserdevelopment;AccountKey=4imAbV9PYXaztSOMpIyAClg88bAZCXuXMGJG0GA1eIBpdh2PlnFGoRBnKqLy2YZUSTmZ3wJfC7tzfHtuC6FEhQ==;EndpointSuffix=core.windows.net'},\n", - " 'bucket': 'pyinfra-test-bucket',\n", - " 'minio': {'access_key': 'root',\n", - " 'endpoint': 'http://127.0.0.1:9000',\n", - " 'region': None,\n", - " 'secret_key': 'password'}},\n", - " 'use_docker_fixture': 1,\n", - " 'webserver': {'host': '$SERVER_HOST|\"127.0.0.1\"',\n", - " 'mode': '$SERVER_MODE|production',\n", - " 'port': '$SERVER_PORT|5000'}}\n" - ] - } - ], - "source": [ - "\n", - "# Open the file and load the file\n", - "with open('./tests/config.yml') as f:\n", - " data = yaml.load(f, Loader=FullLoader)\n", - " pprint.pprint(data)" - ] - }, - { - "cell_type": "code", - "execution_count": 22, - "metadata": {}, - "outputs": [], - "source": [ - "import os\n", - "os.environ[\"STORAGE_BACKEND\"] = \"azure\"\n", - "\n", - "# always the same\n", - "os.environ[\"STORAGE_BUCKET_NAME\"] = \"pyinfra-test-bucket\"\n", - "\n", - "# s3\n", - "os.environ[\"STORAGE_ENDPOINT\"] = \"https://s3.amazonaws.com\"\n", - "os.environ[\"STORAGE_KEY\"] = \"AKIA4QVP6D4LCDAGYGN2\"\n", - "os.environ[\"STORAGE_SECRET\"] = \"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\"\n", - "os.environ[\"STORAGE_REGION\"] = \"eu-west-1\"\n", - "\n", - "# aks\n", - "os.environ[\"STORAGE_AZURECONNECTIONSTRING\"] = \"DefaultEndpointsProtocol=https;AccountName=iqserdevelopment;AccountKey=4imAbV9PYXaztSOMpIyAClg88bAZCXuXMGJG0GA1eIBpdh2PlnFGoRBnKqLy2YZUSTmZ3wJfC7tzfHtuC6FEhQ==;EndpointSuffix=core.windows.net\"" - ] - }, - { - "cell_type": "code", - "execution_count": 23, - "metadata": {}, - "outputs": [ - { - 
"ename": "Exception", - "evalue": "Unknown storage backend 'aks'.", - "output_type": "error", - "traceback": [ - "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mException\u001b[0m Traceback (most recent call last)", - "Cell \u001b[0;32mIn [23], line 2\u001b[0m\n\u001b[1;32m 1\u001b[0m config \u001b[38;5;241m=\u001b[39m pyinfra\u001b[38;5;241m.\u001b[39mconfig\u001b[38;5;241m.\u001b[39mget_config()\n\u001b[0;32m----> 2\u001b[0m storage \u001b[38;5;241m=\u001b[39m \u001b[43mpyinfra\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mstorage\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mget_storage\u001b[49m\u001b[43m(\u001b[49m\u001b[43mconfig\u001b[49m\u001b[43m)\u001b[49m\n", - "File \u001b[0;32m~/dev/pyinfra/pyinfra/storage/storage.py:15\u001b[0m, in \u001b[0;36mget_storage\u001b[0;34m(config)\u001b[0m\n\u001b[1;32m 13\u001b[0m storage \u001b[39m=\u001b[39m get_azure_storage(config)\n\u001b[1;32m 14\u001b[0m \u001b[39melse\u001b[39;00m:\n\u001b[0;32m---> 15\u001b[0m \u001b[39mraise\u001b[39;00m \u001b[39mException\u001b[39;00m(\u001b[39mf\u001b[39m\u001b[39m\"\u001b[39m\u001b[39mUnknown storage backend \u001b[39m\u001b[39m'\u001b[39m\u001b[39m{\u001b[39;00mconfig\u001b[39m.\u001b[39mstorage_backend\u001b[39m}\u001b[39;00m\u001b[39m'\u001b[39m\u001b[39m.\u001b[39m\u001b[39m\"\u001b[39m)\n\u001b[1;32m 17\u001b[0m \u001b[39mreturn\u001b[39;00m storage\n", - "\u001b[0;31mException\u001b[0m: Unknown storage backend 'aks'." 
- ] - } - ], - "source": [ - "config = pyinfra.config.get_config()\n", - "storage = pyinfra.storage.get_storage(config)" - ] - }, - { - "cell_type": "code", - "execution_count": 21, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "False" - ] - }, - "execution_count": 21, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "storage.has_bucket(config.storage_bucket)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3.8.13 ('pyinfra-TboPpZ8z-py3.8')", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.8.13" - }, - "orig_nbformat": 4, - "vscode": { - "interpreter": { - "hash": "10d7419af5ea6dfec0078ebc9d6fa1a9383fe9894853f90dc7d29a81b3de2c78" - } - } - }, - "nbformat": 4, - "nbformat_minor": 2 -} diff --git a/tests/conftest.py b/tests/conftest.py index 7b3fc33..053d592 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -6,8 +6,7 @@ import pytest import testcontainers.compose from pyinfra.config import get_config -from pyinfra.queue.queue_manager import QueueManager -from pyinfra.storage import get_storage +from pyinfra.storage import get_storage_from_config logger = logging.getLogger(__name__) logger.setLevel(logging.DEBUG) @@ -30,60 +29,35 @@ def docker_compose(sleep_seconds=30): @pytest.fixture(scope="session") -def storage_config(client_name): +def test_storage_config(storage_backend, bucket_name, monitoring_enabled): config = get_config() - config.storage_backend = client_name + config.storage_backend = storage_backend + config.storage_bucket = bucket_name config.storage_azureconnectionstring = 
"DefaultEndpointsProtocol=https;AccountName=iqserdevelopment;AccountKey=4imAbV9PYXaztSOMpIyAClg88bAZCXuXMGJG0GA1eIBpdh2PlnFGoRBnKqLy2YZUSTmZ3wJfC7tzfHtuC6FEhQ==;EndpointSuffix=core.windows.net" + config.monitoring_enabled = monitoring_enabled + config.prometheus_metric_prefix = "test" + config.prometheus_port = 8080 + config.prometheus_host = "0.0.0.0" return config @pytest.fixture(scope="session") -def processing_config(storage_config, monitoring_enabled): - storage_config.monitoring_enabled = monitoring_enabled - return storage_config - - -@pytest.fixture(scope="session") -def bucket_name(storage_config): - return storage_config.storage_bucket - - -@pytest.fixture(scope="session") -def storage(storage_config): - logger.debug("Setup for storage") - storage = get_storage(storage_config) - storage.make_bucket(storage_config.storage_bucket) - storage.clear_bucket(storage_config.storage_bucket) - yield storage - logger.debug("Teardown for storage") - try: - storage.clear_bucket(storage_config.storage_bucket) - except: - pass - - -@pytest.fixture(scope="session") -def queue_config(payload_processor_type): +def test_queue_config(): config = get_config() - # FIXME: It looks like rabbitmq_heartbeat has to be greater than rabbitmq_connection_sleep. If this is expected, the - # user should not be abele to insert non working values. 
- config.rabbitmq_heartbeat = config.rabbitmq_connection_sleep + 1 + config.rabbitmq_connection_sleep = 2 + config.rabbitmq_heartbeat = 4 return config -@pytest.fixture(scope="session") -def queue_manager(queue_config): - queue_manager = QueueManager(queue_config) - return queue_manager - - @pytest.fixture -def request_payload(): +def payload(x_tenant_id): + x_tenant_entry = {"X-TENANT-ID": x_tenant_id} if x_tenant_id else {} return { "dossierId": "test", "fileId": "test", "targetFileExtension": "json.gz", "responseFileExtension": "json.gz", + **x_tenant_entry, } @@ -93,3 +67,17 @@ def response_payload(): "dossierId": "test", "fileId": "test", } + + +@pytest.fixture(scope="session") +def storage(test_storage_config): + logger.debug("Setup for storage") + storage = get_storage_from_config(test_storage_config) + storage.make_bucket(test_storage_config.storage_bucket) + storage.clear_bucket(test_storage_config.storage_bucket) + yield storage + logger.debug("Teardown for storage") + try: + storage.clear_bucket(test_storage_config.storage_bucket) + except: + pass diff --git a/tests/monitor_test.py b/tests/monitor_test.py index 391e681..fad8e6a 100644 --- a/tests/monitor_test.py +++ b/tests/monitor_test.py @@ -4,43 +4,41 @@ import time import pytest import requests -from pyinfra.config import get_config -from pyinfra.payload_processing.monitor import get_monitor +from pyinfra.payload_processing.monitor import PrometheusMonitor @pytest.fixture(scope="class") -def monitor_config(): - config = get_config() - config.prometheus_metric_prefix = "monitor_test" - config.prometheus_port = 8000 - return config - - -@pytest.fixture(scope="class") -def prometheus_monitor(monitor_config): - return get_monitor(monitor_config) - - -@pytest.fixture -def monitored_mock_function(prometheus_monitor): +def monitored_mock_function(metric_prefix, host, port): def process(data=None): time.sleep(2) return ["result1", "result2", "result3"] - return prometheus_monitor(process) + monitor = 
PrometheusMonitor(metric_prefix, host, port) + return monitor(process) +@pytest.fixture +def metric_endpoint(host, port): + return f"http://{host}:{port}/prometheus" + + +@pytest.mark.parametrize("metric_prefix, host, port", [("test", "0.0.0.0", 8000)], scope="class") class TestPrometheusMonitor: - def test_prometheus_endpoint_is_available(self, prometheus_monitor, monitor_config): - resp = requests.get(f"http://{monitor_config.prometheus_host}:{monitor_config.prometheus_port}/prometheus") + def test_prometheus_endpoint_is_available(self, metric_endpoint, monitored_mock_function): + resp = requests.get(metric_endpoint) assert resp.status_code == 200 - def test_processing_with_a_monitored_fn_increases_parameter_counter(self, monitored_mock_function, monitor_config): + def test_processing_with_a_monitored_fn_increases_parameter_counter( + self, + metric_endpoint, + metric_prefix, + monitored_mock_function, + ): monitored_mock_function(data=None) - resp = requests.get(f"http://{monitor_config.prometheus_host}:{monitor_config.prometheus_port}/prometheus") - pattern = re.compile(r".*monitor_test_processing_time_count (\d\.\d).*") + resp = requests.get(metric_endpoint) + pattern = re.compile(rf".*{metric_prefix}_processing_time_count (\d\.\d).*") assert pattern.search(resp.text).group(1) == "1.0" monitored_mock_function(data=None) - resp = requests.get(f"http://{monitor_config.prometheus_host}:{monitor_config.prometheus_port}/prometheus") + resp = requests.get(metric_endpoint) assert pattern.search(resp.text).group(1) == "2.0" diff --git a/tests/payload_parsing_test.py b/tests/payload_parsing_test.py new file mode 100644 index 0000000..06d3510 --- /dev/null +++ b/tests/payload_parsing_test.py @@ -0,0 +1,53 @@ +import pytest + +from pyinfra.payload_processing.payload import ( + QueueMessagePayload, + QueueMessagePayloadParser, +) +from pyinfra.utils.file_extension_parsing import make_file_extension_parser + + +@pytest.fixture +def expected_parsed_payload(x_tenant_id): + 
return QueueMessagePayload( + dossier_id="test", + file_id="test", + x_tenant_id=x_tenant_id, + target_file_extension="json.gz", + response_file_extension="json.gz", + target_file_type="json", + target_compression_type="gz", + response_file_type="json", + response_compression_type="gz", + target_file_name="test/test.json.gz", + response_file_name="test/test.json.gz", + ) + + +@pytest.fixture +def file_extension_parser(allowed_file_types, allowed_compression_types): + return make_file_extension_parser(allowed_file_types, allowed_compression_types) + + +@pytest.fixture +def payload_parser(file_extension_parser): + return QueueMessagePayloadParser(file_extension_parser) + + +@pytest.mark.parametrize("allowed_file_types,allowed_compression_types", [(["json", "pdf"], ["gz"])]) +class TestPayload: + @pytest.mark.parametrize("x_tenant_id", [None, "klaus"]) + def test_payload_is_parsed_correctly(self, payload_parser, payload, expected_parsed_payload): + payload = payload_parser(payload) + assert payload == expected_parsed_payload + + @pytest.mark.parametrize( + "extension,expected", + [ + ("json.gz", ("json", "gz")), + ("json", ("json", None)), + ("prefix.json.gz", ("json", "gz")), + ], + ) + def test_parse_file_extension(self, file_extension_parser, extension, expected): + assert file_extension_parser(extension) == expected diff --git a/tests/processing_test.py b/tests/payload_processor_test.py similarity index 63% rename from tests/processing_test.py rename to tests/payload_processor_test.py index 18698d7..0c48ac4 100644 --- a/tests/processing_test.py +++ b/tests/payload_processor_test.py @@ -8,14 +8,6 @@ import requests from pyinfra.payload_processing.processor import make_payload_processor -@pytest.fixture(scope="session") -def file_processor_mock(): - def inner(json_file: dict): - return [json_file] - - return inner - - @pytest.fixture def target_file(): contents = {"numberOfPages": 10, "content1": "value1", "content2": "value2"} @@ -23,54 +15,60 @@ def target_file(): 
@pytest.fixture -def file_names(request_payload): +def file_names(payload): dossier_id, file_id, target_suffix, response_suffix = itemgetter( "dossierId", "fileId", "targetFileExtension", "responseFileExtension", - )(request_payload) + )(payload) return f"{dossier_id}/{file_id}.{target_suffix}", f"{dossier_id}/{file_id}.{response_suffix}" @pytest.fixture(scope="session") -def payload_processor(file_processor_mock, processing_config): - yield make_payload_processor(file_processor_mock, processing_config) +def payload_processor(test_storage_config): + def file_processor_mock(json_file: dict): + return [json_file] + + yield make_payload_processor(file_processor_mock, test_storage_config) -@pytest.mark.parametrize("client_name", ["s3"], scope="session") +@pytest.mark.parametrize("storage_backend", ["s3"], scope="session") +@pytest.mark.parametrize("bucket_name", ["testbucket"], scope="session") @pytest.mark.parametrize("monitoring_enabled", [True, False], scope="session") +@pytest.mark.parametrize("x_tenant_id", [None]) class TestPayloadProcessor: def test_payload_processor_yields_correct_response_and_uploads_result( self, payload_processor, storage, bucket_name, - request_payload, + payload, response_payload, target_file, file_names, ): storage.clear_bucket(bucket_name) storage.put_object(bucket_name, file_names[0], target_file) - response = payload_processor(request_payload) + response = payload_processor(payload) assert response == response_payload data_received = storage.get_object(bucket_name, file_names[1]) assert json.loads((gzip.decompress(data_received)).decode("utf-8")) == { - **request_payload, + **payload, "data": [json.loads(gzip.decompress(target_file).decode("utf-8"))], } - def test_catching_of_processing_failure(self, payload_processor, storage, bucket_name, request_payload): + def test_catching_of_processing_failure(self, payload_processor, storage, bucket_name, payload): storage.clear_bucket(bucket_name) with pytest.raises(Exception): - 
payload_processor(request_payload) + payload_processor(payload) - def test_prometheus_endpoint_is_available(self, processing_config): - resp = requests.get( - f"http://{processing_config.prometheus_host}:{processing_config.prometheus_port}/prometheus" - ) - assert resp.status_code == 200 \ No newline at end of file + def test_prometheus_endpoint_is_available(self, test_storage_config, monitoring_enabled, storage_backend, x_tenant_id): + if monitoring_enabled: + resp = requests.get( + f"http://{test_storage_config.prometheus_host}:{test_storage_config.prometheus_port}/prometheus" + ) + assert resp.status_code == 200 diff --git a/tests/payload_test.py b/tests/payload_test.py deleted file mode 100644 index 5f0bb36..0000000 --- a/tests/payload_test.py +++ /dev/null @@ -1,44 +0,0 @@ -import pytest - -from pyinfra.config import get_config -from pyinfra.payload_processing.payload import ( - QueueMessagePayload, - get_queue_message_payload_parser, -) -from pyinfra.utils.file_extension_parsing import make_file_extension_parser - - -@pytest.fixture(scope="session") -def payload_config(): - return get_config() - - -class TestPayload: - def test_payload_is_parsed_correctly(self, request_payload, payload_config): - parse_payload = get_queue_message_payload_parser(payload_config) - payload = parse_payload(request_payload) - assert payload == QueueMessagePayload( - dossier_id="test", - file_id="test", - target_file_extension="json.gz", - response_file_extension="json.gz", - target_file_type="json", - target_compression_type="gz", - response_file_type="json", - response_compression_type="gz", - target_file_name="test/test.json.gz", - response_file_name="test/test.json.gz", - ) - - @pytest.mark.parametrize( - "extension,expected", - [ - ("json.gz", ("json", "gz")), - ("json", ("json", None)), - ("prefix.json.gz", ("json", "gz")), - ], - ) - @pytest.mark.parametrize("allowed_file_types,allowed_compression_types", [(["json", "pdf"], ["gz"])]) - def test_parse_file_extension(self, 
extension, expected, allowed_file_types, allowed_compression_types): - parse = make_file_extension_parser(allowed_file_types, allowed_compression_types) - assert parse(extension) == expected diff --git a/tests/queue_test.py b/tests/queue_manager_test.py similarity index 78% rename from tests/queue_test.py rename to tests/queue_manager_test.py index 8816e20..d6c9118 100644 --- a/tests/queue_test.py +++ b/tests/queue_manager_test.py @@ -8,15 +8,16 @@ import pika.exceptions import pytest from pyinfra.queue.development_queue_manager import DevelopmentQueueManager +from pyinfra.queue.queue_manager import QueueManager logger = logging.getLogger(__name__) logger.setLevel(logging.DEBUG) @pytest.fixture(scope="session") -def development_queue_manager(queue_config): - queue_config.rabbitmq_heartbeat = 7200 - development_queue_manager = DevelopmentQueueManager(queue_config) +def development_queue_manager(test_queue_config): + test_queue_config.rabbitmq_heartbeat = 7200 + development_queue_manager = DevelopmentQueueManager(test_queue_config) yield development_queue_manager logger.info("Tearing down development queue manager...") try: @@ -26,10 +27,10 @@ def development_queue_manager(queue_config): @pytest.fixture(scope="session") -def payload_processing_time(queue_config, offset=5): +def payload_processing_time(test_queue_config, offset=5): # FIXME: this implicitly tests the heartbeat when running the end-to-end test. There should be another way to test # this explicitly. 
- return queue_config.rabbitmq_heartbeat + offset + return test_queue_config.rabbitmq_heartbeat + offset @pytest.fixture(scope="session") @@ -48,10 +49,11 @@ def payload_processor(response_payload, payload_processing_time, payload_process @pytest.fixture(scope="session", autouse=True) -def start_queue_consumer(queue_manager, payload_processor, sleep_seconds=5): +def start_queue_consumer(test_queue_config, payload_processor, sleep_seconds=5): def consume_queue(): queue_manager.start_consuming(payload_processor) + queue_manager = QueueManager(test_queue_config) p = Process(target=consume_queue) p.start() logger.info(f"Setting up consumer, waiting for {sleep_seconds}...") @@ -65,39 +67,40 @@ def start_queue_consumer(queue_manager, payload_processor, sleep_seconds=5): def message_properties(message_headers): if not message_headers: return pika.BasicProperties(headers=None) - elif message_headers == "x-tenant-id": - return pika.BasicProperties(headers={"x-tenant-id": "redaction"}) + elif message_headers == "X-TENANT-ID": + return pika.BasicProperties(headers={"X-TENANT-ID": "redaction"}) else: raise Exception(f"Invalid {message_headers=}.") +@pytest.mark.parametrize("x_tenant_id", [None]) class TestQueueManager: # FIXME: All tests here are wonky. This is due to the implementation of running the process-blocking queue_manager # in a subprocess. It is then very hard to interact directly with the subprocess. If you have a better idea, please # refactor; the tests here are insufficient to ensure the functionality of the queue manager! 
@pytest.mark.parametrize("payload_processor_type", ["mock"], scope="session") def test_message_processing_does_not_block_heartbeat( - self, development_queue_manager, request_payload, response_payload, payload_processing_time + self, development_queue_manager, payload, response_payload, payload_processing_time ): development_queue_manager.clear_queues() - development_queue_manager.publish_request(request_payload) + development_queue_manager.publish_request(payload) time.sleep(payload_processing_time + 10) _, _, body = development_queue_manager.get_response() result = json.loads(body) assert result == response_payload - @pytest.mark.parametrize("message_headers", [None, "x-tenant-id"]) + @pytest.mark.parametrize("message_headers", [None, "X-TENANT-ID"]) @pytest.mark.parametrize("payload_processor_type", ["mock"], scope="session") def test_queue_manager_forwards_message_headers( self, development_queue_manager, - request_payload, + payload, response_payload, payload_processing_time, message_properties, ): development_queue_manager.clear_queues() - development_queue_manager.publish_request(request_payload, message_properties) + development_queue_manager.publish_request(payload, message_properties) time.sleep(payload_processing_time + 10) _, properties, _ = development_queue_manager.get_response() assert properties.headers == message_properties.headers @@ -109,12 +112,12 @@ class TestQueueManager: def test_failed_message_processing_is_handled( self, development_queue_manager, - request_payload, + payload, response_payload, payload_processing_time, ): development_queue_manager.clear_queues() - development_queue_manager.publish_request(request_payload) + development_queue_manager.publish_request(payload) time.sleep(payload_processing_time + 10) _, _, body = development_queue_manager.get_response() assert not body diff --git a/tests/storage_test.py b/tests/storage_test.py index 9d1635d..80ab4b0 100644 --- a/tests/storage_test.py +++ b/tests/storage_test.py @@ -6,7 +6,9 @@ 
logger = logging.getLogger(__name__) logger.setLevel(logging.DEBUG) -@pytest.mark.parametrize("client_name", ["azure", "s3"], scope="session") +@pytest.mark.parametrize("storage_backend", ["azure", "s3"], scope="session") +@pytest.mark.parametrize("bucket_name", ["testbucket"], scope="session") +@pytest.mark.parametrize("monitoring_enabled", [False], scope="session") class TestStorage: def test_clearing_bucket_yields_empty_bucket(self, storage, bucket_name): storage.clear_bucket(bucket_name)