add multi-tenant storage connection 1st iteration

- forward x-tenant-id from queue message header to
payload processor
- add functions to receive storage infos from an
endpoint or the config. This enables hashing and
caching of connections created from these infos
- add function to initialize storage connections
from storage infos
- streamline and refactor tests to make them more
readable and robust and to make it easier to add
 new tests
- update payload processor with first iteration
of multi tenancy storage connection support
with connection caching and backwards compatibility
This commit is contained in:
Julius Unverfehrt 2023-03-22 13:34:43 +01:00
parent 52c047c47b
commit d48e8108fd
21 changed files with 353 additions and 390 deletions

View File

@ -28,11 +28,9 @@ class Config:
"PROMETHEUS_METRIC_PREFIX", "redactmanager_research_service_parameter"
)
# Prometheus webserver address
self.prometheus_host = read_from_environment("PROMETHEUS_HOST", "127.0.0.1")
# Prometheus webserver port
self.prometheus_port = int(read_from_environment("PROMETHEUS_PORT", 8080))
# Prometheus webserver address and port
self.prometheus_host = "0.0.0.0"
self.prometheus_port = 8080
# RabbitMQ host address
self.rabbitmq_host = read_from_environment("RABBITMQ_HOST", "localhost")
@ -94,6 +92,10 @@ class Config:
self.allowed_file_types = ["json", "pdf"]
self.allowed_compression_types = ["gz"]
# config for x-tenant-endpoint to receive storage connection information per tenant
self.persistence_service_public_key = "redaction"
self.persistence_service_tenant_endpoint = "http://persistence-service-v1:8080/internal-api/tenants"
# Value to see if we should write a consumer token to a file
self.write_consumer_token = read_from_environment("WRITE_CONSUMER_TOKEN", "False")

View File

@ -1,2 +1,5 @@
class ProcessingFailure(RuntimeError):
    """Raised when processing of a queue message payload fails."""

    pass
class UnknownStorageBackend(Exception):
    """Raised when storage configuration or info does not match any supported backend."""

    pass

View File

@ -12,7 +12,7 @@ logger = logging.getLogger()
class PrometheusMonitor:
def __init__(self, prefix: str, port=8080, host="127.0.0.1"):
def __init__(self, prefix: str, host: str, port: int):
"""Register the monitoring metrics and start a webserver where they can be scraped at the endpoint
http://{host}:{port}/prometheus
@ -23,12 +23,9 @@ class PrometheusMonitor:
self.registry = CollectorRegistry()
self.entity_processing_time_sum = Summary(
f"{prefix}_processing_time",
"Summed up average processing time per entity observed",
f"{prefix}_processing_time", "Summed up average processing time per entity observed", registry=self.registry
)
self.registry.register(self.entity_processing_time_sum)
start_http_server(port, host, self.registry)
def __call__(self, process_fn: Callable) -> Callable:
@ -58,8 +55,8 @@ class PrometheusMonitor:
return inner
def get_monitor(config: Config) -> Callable:
def get_monitor_from_config(config: Config) -> Callable:
if config.monitoring_enabled:
return PrometheusMonitor(*attrgetter("prometheus_metric_prefix", "prometheus_port", "prometheus_host")(config))
return PrometheusMonitor(*attrgetter("prometheus_metric_prefix", "prometheus_host", "prometheus_port")(config))
else:
return identity

View File

@ -11,6 +11,8 @@ from pyinfra.utils.file_extension_parsing import make_file_extension_parser
class QueueMessagePayload:
dossier_id: str
file_id: str
x_tenant_id: Union[str, None]
target_file_extension: str
response_file_extension: str
@ -35,6 +37,7 @@ class QueueMessagePayloadParser:
dossier_id, file_id, target_file_extension, response_file_extension = itemgetter(
"dossierId", "fileId", "targetFileExtension", "responseFileExtension"
)(payload)
x_tenant_id = payload.get("X-TENANT-ID")
target_file_type, target_compression_type, response_file_type, response_compression_type = chain.from_iterable(
map(self.parse_file_extensions, [target_file_extension, response_file_extension])
@ -46,6 +49,7 @@ class QueueMessagePayloadParser:
return QueueMessagePayload(
dossier_id=dossier_id,
file_id=file_id,
x_tenant_id=x_tenant_id,
target_file_extension=target_file_extension,
response_file_extension=response_file_extension,
target_file_type=target_file_type,

View File

@ -6,16 +6,20 @@ from typing import Callable, Union, List
from funcy import compose
from pyinfra.config import get_config, Config
from pyinfra.payload_processing.monitor import get_monitor
from pyinfra.payload_processing.monitor import get_monitor_from_config
from pyinfra.payload_processing.payload import (
QueueMessagePayloadParser,
get_queue_message_payload_parser,
QueueMessagePayloadFormatter,
get_queue_message_payload_formatter,
)
from pyinfra.storage import get_storage
from pyinfra.storage.storage import make_downloader, make_uploader
from pyinfra.storage.storages.interface import Storage
from pyinfra.storage.storage import make_downloader, make_uploader, get_storage_from_storage_info
from pyinfra.storage.storage_info import (
AzureStorageInfo,
S3StorageInfo,
get_storage_info_from_config,
get_storage_info_from_endpoint,
)
logger = logging.getLogger()
logger.setLevel(get_config().logging_level_root)
@ -24,8 +28,8 @@ logger.setLevel(get_config().logging_level_root)
class PayloadProcessor:
def __init__(
self,
storage: Storage,
bucket: str,
default_storage_info: Union[AzureStorageInfo, S3StorageInfo],
get_storage_info_from_tenant_id,
payload_parser: QueueMessagePayloadParser,
payload_formatter: QueueMessagePayloadFormatter,
data_processor: Callable,
@ -33,8 +37,9 @@ class PayloadProcessor:
"""Wraps an analysis function specified by a service (e.g. NER service) in pre- and post-processing steps.
Args:
storage: The storage to use for downloading and uploading files
bucket: The bucket to use for downloading and uploading files
default_storage_info: The default storage info used to create the storage connection. This is only used if
x_tenant_id is not provided in the queue payload.
get_storage_info_from_tenant_id: Callable to acquire storage info from a given tenant id.
payload_parser: Parser that translates the queue message payload to the required QueueMessagePayload object
payload_formatter: Formatter for the storage upload result and the queue message response body
data_processor: The analysis function to be called with the downloaded file
@ -46,8 +51,10 @@ class PayloadProcessor:
self.format_to_queue_message_response_body = payload_formatter.format_to_queue_message_response_body
self.process_data = data_processor
self.make_downloader = partial(make_downloader, storage, bucket)
self.make_uploader = partial(make_uploader, storage, bucket)
self.get_storage_info_from_tenant_id = get_storage_info_from_tenant_id
self.default_storage_info = default_storage_info
# TODO: use lru-dict
self.storages = {}
def __call__(self, queue_message_payload: dict) -> dict:
"""Processes a queue message payload.
@ -71,8 +78,16 @@ class PayloadProcessor:
payload = self.parse_payload(queue_message_payload)
logger.info(f"Processing {asdict(payload)} ...")
download_file_to_process = self.make_downloader(payload.target_file_type, payload.target_compression_type)
upload_processing_result = self.make_uploader(payload.response_file_type, payload.response_compression_type)
storage_info = self._get_storage_info(payload.x_tenant_id)
bucket = storage_info.bucket_name
storage = self._get_storage(storage_info)
download_file_to_process = make_downloader(
storage, bucket, payload.target_file_type, payload.target_compression_type
)
upload_processing_result = make_uploader(
storage, bucket, payload.response_file_type, payload.response_compression_type
)
format_result_for_storage = partial(self.format_result_for_storage, payload)
processing_pipeline = compose(format_result_for_storage, self.process_data, download_file_to_process)
@ -83,17 +98,40 @@ class PayloadProcessor:
return self.format_to_queue_message_response_body(payload)
def _get_storage_info(self, x_tenant_id=None):
    """Resolve the storage info for a tenant.

    Falls back to ``self.default_storage_info`` when no tenant id is present
    in the payload (single-tenant / backwards-compatible path).
    """
    if x_tenant_id:
        return self.get_storage_info_from_tenant_id(x_tenant_id)
    return self.default_storage_info
def _get_storage(self, storage_info):
    """Return a storage connection for *storage_info*, creating and caching it on first use.

    NOTE(review): the cache in ``self.storages`` is currently unbounded
    (see the TODO about using an lru-dict where it is initialized).
    """
    if storage_info in self.storages:
        return self.storages[storage_info]
    else:
        storage = get_storage_from_storage_info(storage_info)
        self.storages[storage_info] = storage
        return storage
def make_payload_processor(data_processor: Callable, config: Union[None, Config] = None) -> PayloadProcessor:
"""Produces payload processor for queue manager."""
config = config or get_config()
bucket: str = config.storage_bucket
storage: Storage = get_storage(config)
monitor = get_monitor(config)
default_storage_info: Union[AzureStorageInfo, S3StorageInfo] = get_storage_info_from_config(config)
get_storage_info_from_tenant_id = partial(
get_storage_info_from_endpoint,
config.persistence_service_public_key,
config.persistence_service_tenant_endpoint,
)
monitor = get_monitor_from_config(config)
payload_parser: QueueMessagePayloadParser = get_queue_message_payload_parser(config)
payload_formatter: QueueMessagePayloadFormatter = get_queue_message_payload_formatter()
data_processor = monitor(data_processor)
return PayloadProcessor(storage, bucket, payload_parser, payload_formatter, data_processor)
return PayloadProcessor(
default_storage_info,
get_storage_info_from_tenant_id,
payload_parser,
payload_formatter,
data_processor,
)

View File

@ -12,6 +12,7 @@ from pika.adapters.blocking_connection import BlockingChannel
from pyinfra.config import Config
from pyinfra.exception import ProcessingFailure
from pyinfra.payload_processing.processor import PayloadProcessor
from pyinfra.utils.dict import save_project
CONFIG = Config()
@ -164,8 +165,8 @@ class QueueManager:
except Exception as err:
raise ProcessingFailure("QueueMessagePayload processing failed") from err
def acknowledge_message_and_publish_response(frame, properties, response_body):
response_properties = pika.BasicProperties(headers=properties.headers) if properties.headers else None
def acknowledge_message_and_publish_response(frame, headers, response_body):
response_properties = pika.BasicProperties(headers=headers) if headers else None
self._channel.basic_publish("", self._output_queue, json.dumps(response_body).encode(), response_properties)
self.logger.info(
"Result published, acknowledging incoming message with delivery_tag %s",
@ -190,12 +191,15 @@ class QueueManager:
try:
self.logger.debug("Processing (%s, %s, %s)", frame, properties, body)
processing_result = process_message_body_and_await_result(json.loads(body))
filtered_message_headers = save_project(properties.headers, ["X-TENANT-ID"]) # TODO: parametrize key?
message_body = {**json.loads(body), **filtered_message_headers}
processing_result = process_message_body_and_await_result(message_body)
self.logger.info(
"Processed message with delivery_tag %s, publishing result to result-queue",
frame.delivery_tag,
)
acknowledge_message_and_publish_response(frame, properties, processing_result)
acknowledge_message_and_publish_response(frame, filtered_message_headers, processing_result)
except ProcessingFailure:
self.logger.info(

View File

@ -1,3 +1,3 @@
from pyinfra.storage.storage import get_storage
from pyinfra.storage.storage import get_storage_from_config
__all__ = ["get_storage"]
__all__ = ["get_storage_from_config"]

View File

@ -1,28 +1,45 @@
from functools import lru_cache, partial
from typing import Callable
from typing import Callable, Union
from azure.storage.blob import BlobServiceClient
from funcy import compose
from minio import Minio
from pyinfra.config import Config
from pyinfra.storage.storages.azure import get_azure_storage
from pyinfra.storage.storages.s3 import get_s3_storage
from pyinfra.exception import UnknownStorageBackend
from pyinfra.storage.storage_info import AzureStorageInfo, S3StorageInfo, get_storage_info_from_config
from pyinfra.storage.storages.azure import AzureStorage
from pyinfra.storage.storages.interface import Storage
from pyinfra.storage.storages.s3 import S3Storage
from pyinfra.utils.compressing import get_decompressor, get_compressor
from pyinfra.utils.encoding import get_decoder, get_encoder
def get_storage(config: Config) -> Storage:
def get_storage_from_config(config: Config) -> Storage:
    """Create a storage client from the application config.

    Derives the storage info from *config* and delegates connection
    creation to ``get_storage_from_storage_info``.
    """
    return get_storage_from_storage_info(get_storage_info_from_config(config))
def get_storage_from_storage_info(storage_info: Union[AzureStorageInfo, S3StorageInfo]) -> Storage:
    """Create a storage client for the backend described by *storage_info*.

    Args:
        storage_info: Connection information for either an Azure blob store
            or an S3-compatible object store.

    Returns:
        A ``Storage`` implementation wrapping the backend client.

    Raises:
        UnknownStorageBackend: If *storage_info* is neither an
            ``AzureStorageInfo`` nor an ``S3StorageInfo``.
    """
    if isinstance(storage_info, AzureStorageInfo):
        return AzureStorage(BlobServiceClient.from_connection_string(conn_str=storage_info.connection_string))
    if isinstance(storage_info, S3StorageInfo):
        return S3Storage(
            Minio(
                secure=storage_info.secure,
                endpoint=storage_info.endpoint,
                access_key=storage_info.access_key,
                secret_key=storage_info.secret_key,
                region=storage_info.region,
            )
        )
    # Include the offending type in the message (previously raised without any
    # context, unlike the sibling get_storage_info_from_config).
    raise UnknownStorageBackend(f"Unknown storage info type '{type(storage_info).__name__}'.")
def verify_existence(storage: Storage, bucket: str, file_name: str) -> str:
if not storage.exists(bucket, file_name):
raise FileNotFoundError(f"{file_name=} name not found on storage in {bucket=}.")

View File

@ -0,0 +1,89 @@
from dataclasses import dataclass
from typing import Union
import requests
from pyinfra.config import Config
from pyinfra.exception import UnknownStorageBackend
from pyinfra.utils.cipher import decrypt
from pyinfra.utils.url_parsing import validate_and_parse_s3_endpoint
@dataclass(frozen=True)
class AzureStorageInfo:
    """Connection information for an Azure blob storage backend."""

    # Azure storage account connection string (contains credentials).
    connection_string: str
    # Name of the blob container used as the bucket.
    bucket_name: str

    def __hash__(self):
        # Hash only the connection string; the dataclass-generated __eq__
        # still compares all fields, including bucket_name.
        return hash(self.connection_string)
@dataclass(frozen=True)
class S3StorageInfo:
    """Connection information for an S3-compatible storage backend."""

    # Whether to connect via TLS.
    secure: bool
    # Host[:port] of the S3 endpoint (scheme already stripped).
    endpoint: str
    access_key: str
    secret_key: str
    # Region of the bucket; semantics depend on the S3 provider.
    region: str
    # Target bucket; deliberately excluded from the hash below.
    bucket_name: str

    def __hash__(self):
        # Hash over the connection parameters only; the dataclass-generated
        # __eq__ still compares all fields, including bucket_name.
        return hash((self.secure, self.endpoint, self.access_key, self.secret_key, self.region))
def get_storage_info_from_endpoint(
    public_key: str, endpoint: str, x_tenant_id: str
) -> Union[AzureStorageInfo, S3StorageInfo]:
    """Fetch and decrypt storage connection information for a tenant.

    Args:
        public_key: Key used to decrypt secrets in the endpoint response
            (previously hard-coded to "redaction" here; the caller-supplied
            value is now honored, resolving the former FIXME).
        endpoint: Base URL of the tenant endpoint.
        x_tenant_id: Tenant id appended to the endpoint URL.

    Returns:
        Storage info for the tenant's Azure or S3 backend.

    Raises:
        UnknownStorageBackend: If the response contains neither connection type.
        ValueError: If the response contains both connection types.
        requests.HTTPError: If the endpoint returns an error status.
    """
    response = requests.get(f"{endpoint}/{x_tenant_id}", timeout=10)
    response.raise_for_status()
    resp = response.json()
    maybe_azure = resp.get("azureStorageConnection")
    # BUGFIX: previously read "azureStorageConnection" a second time, so S3
    # tenants were never detected and Azure tenants tripped the exclusivity
    # check. TODO(review): confirm the exact response key name.
    maybe_s3 = resp.get("s3StorageConnection")
    if maybe_azure and maybe_s3:
        # Explicit raise instead of assert (asserts are stripped under -O).
        raise ValueError("Tenant endpoint returned both Azure and S3 connection information.")
    if maybe_azure:
        connection_string = decrypt(public_key, maybe_azure["connectionString"])
        return AzureStorageInfo(
            connection_string=connection_string,
            bucket_name=maybe_azure["containerName"],
        )
    if maybe_s3:
        secure, s3_endpoint = validate_and_parse_s3_endpoint(maybe_s3["endpoint"])
        secret = decrypt(public_key, maybe_s3["secret"])
        return S3StorageInfo(
            secure=secure,
            endpoint=s3_endpoint,
            access_key=maybe_s3["key"],
            secret_key=secret,
            region=maybe_s3["region"],
            # BUGFIX: previously passed the whole maybe_s3 dict as bucket_name.
            # TODO(review): confirm the exact response key name.
            bucket_name=maybe_s3["bucket"],
        )
    raise UnknownStorageBackend(f"No known storage connection returned for tenant '{x_tenant_id}'.")
def get_storage_info_from_config(config: Config) -> Union[AzureStorageInfo, S3StorageInfo]:
    """Build the default (non-tenant) storage info from the application config.

    Raises:
        UnknownStorageBackend: If ``config.storage_backend`` is not "s3" or "azure".
    """
    backend = config.storage_backend
    if backend == "s3":
        return S3StorageInfo(
            secure=config.storage_secure_connection,
            endpoint=config.storage_endpoint,
            access_key=config.storage_key,
            secret_key=config.storage_secret,
            region=config.storage_region,
            bucket_name=config.storage_bucket,
        )
    if backend == "azure":
        return AzureStorageInfo(
            connection_string=config.storage_azureconnectionstring,
            bucket_name=config.storage_bucket,
        )
    raise UnknownStorageBackend(f"Unknown storage backend '{backend}'.")

View File

@ -77,5 +77,5 @@ class AzureStorage(Storage):
return zip(repeat(bucket_name), map(attrgetter("name"), blobs))
def get_azure_storage(config: Config):
def get_azure_storage_from_config(config: Config):
return AzureStorage(BlobServiceClient.from_connection_string(conn_str=config.storage_azureconnectionstring))

View File

@ -67,7 +67,7 @@ class S3Storage(Storage):
return zip(repeat(bucket_name), map(attrgetter("object_name"), objs))
def get_s3_storage(config: Config):
def get_s3_storage_from_config(config: Config):
return S3Storage(
Minio(
secure=config.storage_secure_connection,

5
pyinfra/utils/dict.py Normal file
View File

@ -0,0 +1,5 @@
from funcy import project
def save_project(mapping, keys) -> dict:
    """None-safe key projection.

    Returns the sub-mapping of *mapping* restricted to *keys*, or an empty
    dict when *mapping* is None or empty.
    """
    if not mapping:
        return {}
    return project(mapping, keys)

View File

@ -7,7 +7,7 @@ import pika
from pyinfra.config import get_config
from pyinfra.queue.development_queue_manager import DevelopmentQueueManager
from pyinfra.storage.storages.s3 import get_s3_storage
from pyinfra.storage.storages.s3 import get_s3_storage_from_config
CONFIG = get_config()
logging.basicConfig()
@ -26,7 +26,7 @@ def upload_json_and_make_message_body():
object_name = f"{dossier_id}/{file_id}.{suffix}"
data = gzip.compress(json.dumps(content).encode("utf-8"))
storage = get_s3_storage(CONFIG)
storage = get_s3_storage_from_config(CONFIG)
if not storage.has_bucket(bucket):
storage.make_bucket(bucket)
storage.put_object(bucket, object_name, data)
@ -46,10 +46,10 @@ def main():
message = upload_json_and_make_message_body()
development_queue_manager.publish_request(message, pika.BasicProperties(headers={"x-tenant-id": "redaction"}))
development_queue_manager.publish_request(message, pika.BasicProperties(headers={"X-TENANT-ID": "redaction"}))
logger.info(f"Put {message} on {CONFIG.request_queue}")
storage = get_s3_storage(CONFIG)
storage = get_s3_storage_from_config(CONFIG)
for method_frame, properties, body in development_queue_manager._channel.consume(
queue=CONFIG.response_queue, inactivity_timeout=15
):

View File

@ -1,194 +0,0 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"ename": "ModuleNotFoundError",
"evalue": "No module named 'pprint.pprint'; 'pprint' is not a package",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mModuleNotFoundError\u001b[0m Traceback (most recent call last)",
"Cell \u001b[0;32mIn [10], line 4\u001b[0m\n\u001b[1;32m 2\u001b[0m \u001b[38;5;28;01mimport\u001b[39;00m \u001b[38;5;21;01myaml\u001b[39;00m\n\u001b[1;32m 3\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01myaml\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mloader\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m FullLoader\n\u001b[0;32m----> 4\u001b[0m \u001b[38;5;28;01mimport\u001b[39;00m \u001b[38;5;21;01mpprint\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mpprint\u001b[39;00m \u001b[38;5;28;01mas\u001b[39;00m \u001b[38;5;21;01mpp\u001b[39;00m\n",
"\u001b[0;31mModuleNotFoundError\u001b[0m: No module named 'pprint.pprint'; 'pprint' is not a package"
]
}
],
"source": [
"import pyinfra\n",
"import yaml\n",
"from yaml.loader import FullLoader\n",
"import pprint"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"{'logging': 0,\n",
" 'mock_analysis_endpoint': 'http://127.0.0.1:5000',\n",
" 'service': {'operations': {'classify': {'input': {'extension': 'cls_in.gz',\n",
" 'multi': True,\n",
" 'subdir': ''},\n",
" 'output': {'extension': 'cls_out.gz',\n",
" 'subdir': ''}},\n",
" 'default': {'input': {'extension': 'IN.gz',\n",
" 'multi': False,\n",
" 'subdir': ''},\n",
" 'output': {'extension': 'OUT.gz',\n",
" 'subdir': ''}},\n",
" 'extract': {'input': {'extension': 'extr_in.gz',\n",
" 'multi': False,\n",
" 'subdir': ''},\n",
" 'output': {'extension': 'gz',\n",
" 'subdir': 'extractions'}},\n",
" 'rotate': {'input': {'extension': 'rot_in.gz',\n",
" 'multi': False,\n",
" 'subdir': ''},\n",
" 'output': {'extension': 'rot_out.gz',\n",
" 'subdir': ''}},\n",
" 'stream_pages': {'input': {'extension': 'pgs_in.gz',\n",
" 'multi': False,\n",
" 'subdir': ''},\n",
" 'output': {'extension': 'pgs_out.gz',\n",
" 'subdir': 'pages'}},\n",
" 'upper': {'input': {'extension': 'up_in.gz',\n",
" 'multi': False,\n",
" 'subdir': ''},\n",
" 'output': {'extension': 'up_out.gz',\n",
" 'subdir': ''}}},\n",
" 'response_formatter': 'identity'},\n",
" 'storage': {'aws': {'access_key': 'AKIA4QVP6D4LCDAGYGN2',\n",
" 'endpoint': 'https://s3.amazonaws.com',\n",
" 'region': '$STORAGE_REGION|\"eu-west-1\"',\n",
" 'secret_key': '8N6H1TUHTsbvW2qMAm7zZlJ63hMqjcXAsdN7TYED'},\n",
" 'azure': {'connection_string': 'DefaultEndpointsProtocol=https;AccountName=iqserdevelopment;AccountKey=4imAbV9PYXaztSOMpIyAClg88bAZCXuXMGJG0GA1eIBpdh2PlnFGoRBnKqLy2YZUSTmZ3wJfC7tzfHtuC6FEhQ==;EndpointSuffix=core.windows.net'},\n",
" 'bucket': 'pyinfra-test-bucket',\n",
" 'minio': {'access_key': 'root',\n",
" 'endpoint': 'http://127.0.0.1:9000',\n",
" 'region': None,\n",
" 'secret_key': 'password'}},\n",
" 'use_docker_fixture': 1,\n",
" 'webserver': {'host': '$SERVER_HOST|\"127.0.0.1\"',\n",
" 'mode': '$SERVER_MODE|production',\n",
" 'port': '$SERVER_PORT|5000'}}\n"
]
}
],
"source": [
"\n",
"# Open the file and load the file\n",
"with open('./tests/config.yml') as f:\n",
" data = yaml.load(f, Loader=FullLoader)\n",
" pprint.pprint(data)"
]
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"os.environ[\"STORAGE_BACKEND\"] = \"azure\"\n",
"\n",
"# always the same\n",
"os.environ[\"STORAGE_BUCKET_NAME\"] = \"pyinfra-test-bucket\"\n",
"\n",
"# s3\n",
"os.environ[\"STORAGE_ENDPOINT\"] = \"https://s3.amazonaws.com\"\n",
"os.environ[\"STORAGE_KEY\"] = \"AKIA4QVP6D4LCDAGYGN2\"\n",
"os.environ[\"STORAGE_SECRET\"] = \"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\"\n",
"os.environ[\"STORAGE_REGION\"] = \"eu-west-1\"\n",
"\n",
"# aks\n",
"os.environ[\"STORAGE_AZURECONNECTIONSTRING\"] = \"DefaultEndpointsProtocol=https;AccountName=iqserdevelopment;AccountKey=4imAbV9PYXaztSOMpIyAClg88bAZCXuXMGJG0GA1eIBpdh2PlnFGoRBnKqLy2YZUSTmZ3wJfC7tzfHtuC6FEhQ==;EndpointSuffix=core.windows.net\""
]
},
{
"cell_type": "code",
"execution_count": 23,
"metadata": {},
"outputs": [
{
"ename": "Exception",
"evalue": "Unknown storage backend 'aks'.",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mException\u001b[0m Traceback (most recent call last)",
"Cell \u001b[0;32mIn [23], line 2\u001b[0m\n\u001b[1;32m 1\u001b[0m config \u001b[38;5;241m=\u001b[39m pyinfra\u001b[38;5;241m.\u001b[39mconfig\u001b[38;5;241m.\u001b[39mget_config()\n\u001b[0;32m----> 2\u001b[0m storage \u001b[38;5;241m=\u001b[39m \u001b[43mpyinfra\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mstorage\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mget_storage\u001b[49m\u001b[43m(\u001b[49m\u001b[43mconfig\u001b[49m\u001b[43m)\u001b[49m\n",
"File \u001b[0;32m~/dev/pyinfra/pyinfra/storage/storage.py:15\u001b[0m, in \u001b[0;36mget_storage\u001b[0;34m(config)\u001b[0m\n\u001b[1;32m 13\u001b[0m storage \u001b[39m=\u001b[39m get_azure_storage(config)\n\u001b[1;32m 14\u001b[0m \u001b[39melse\u001b[39;00m:\n\u001b[0;32m---> 15\u001b[0m \u001b[39mraise\u001b[39;00m \u001b[39mException\u001b[39;00m(\u001b[39mf\u001b[39m\u001b[39m\"\u001b[39m\u001b[39mUnknown storage backend \u001b[39m\u001b[39m'\u001b[39m\u001b[39m{\u001b[39;00mconfig\u001b[39m.\u001b[39mstorage_backend\u001b[39m}\u001b[39;00m\u001b[39m'\u001b[39m\u001b[39m.\u001b[39m\u001b[39m\"\u001b[39m)\n\u001b[1;32m 17\u001b[0m \u001b[39mreturn\u001b[39;00m storage\n",
"\u001b[0;31mException\u001b[0m: Unknown storage backend 'aks'."
]
}
],
"source": [
"config = pyinfra.config.get_config()\n",
"storage = pyinfra.storage.get_storage(config)"
]
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"False"
]
},
"execution_count": 21,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"storage.has_bucket(config.storage_bucket)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3.8.13 ('pyinfra-TboPpZ8z-py3.8')",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.13"
},
"orig_nbformat": 4,
"vscode": {
"interpreter": {
"hash": "10d7419af5ea6dfec0078ebc9d6fa1a9383fe9894853f90dc7d29a81b3de2c78"
}
}
},
"nbformat": 4,
"nbformat_minor": 2
}

View File

@ -6,8 +6,7 @@ import pytest
import testcontainers.compose
from pyinfra.config import get_config
from pyinfra.queue.queue_manager import QueueManager
from pyinfra.storage import get_storage
from pyinfra.storage import get_storage_from_config
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)
@ -30,60 +29,35 @@ def docker_compose(sleep_seconds=30):
@pytest.fixture(scope="session")
def storage_config(client_name):
def test_storage_config(storage_backend, bucket_name, monitoring_enabled):
config = get_config()
config.storage_backend = client_name
config.storage_backend = storage_backend
config.storage_bucket = bucket_name
config.storage_azureconnectionstring = "DefaultEndpointsProtocol=https;AccountName=iqserdevelopment;AccountKey=4imAbV9PYXaztSOMpIyAClg88bAZCXuXMGJG0GA1eIBpdh2PlnFGoRBnKqLy2YZUSTmZ3wJfC7tzfHtuC6FEhQ==;EndpointSuffix=core.windows.net"
config.monitoring_enabled = monitoring_enabled
config.prometheus_metric_prefix = "test"
config.prometheus_port = 8080
config.prometheus_host = "0.0.0.0"
return config
@pytest.fixture(scope="session")
def processing_config(storage_config, monitoring_enabled):
storage_config.monitoring_enabled = monitoring_enabled
return storage_config
@pytest.fixture(scope="session")
def bucket_name(storage_config):
return storage_config.storage_bucket
@pytest.fixture(scope="session")
def storage(storage_config):
logger.debug("Setup for storage")
storage = get_storage(storage_config)
storage.make_bucket(storage_config.storage_bucket)
storage.clear_bucket(storage_config.storage_bucket)
yield storage
logger.debug("Teardown for storage")
try:
storage.clear_bucket(storage_config.storage_bucket)
except:
pass
@pytest.fixture(scope="session")
def queue_config(payload_processor_type):
def test_queue_config():
config = get_config()
# FIXME: It looks like rabbitmq_heartbeat has to be greater than rabbitmq_connection_sleep. If this is expected, the
# user should not be able to insert non-working values.
config.rabbitmq_heartbeat = config.rabbitmq_connection_sleep + 1
config.rabbitmq_connection_sleep = 2
config.rabbitmq_heartbeat = 4
return config
@pytest.fixture(scope="session")
def queue_manager(queue_config):
queue_manager = QueueManager(queue_config)
return queue_manager
@pytest.fixture
def request_payload():
def payload(x_tenant_id):
x_tenant_entry = {"X-TENANT-ID": x_tenant_id} if x_tenant_id else {}
return {
"dossierId": "test",
"fileId": "test",
"targetFileExtension": "json.gz",
"responseFileExtension": "json.gz",
**x_tenant_entry,
}
@ -93,3 +67,17 @@ def response_payload():
"dossierId": "test",
"fileId": "test",
}
@pytest.fixture(scope="session")
def storage(test_storage_config):
    """Session-scoped storage connection against a freshly emptied test bucket."""
    logger.debug("Setup for storage")
    storage = get_storage_from_config(test_storage_config)
    storage.make_bucket(test_storage_config.storage_bucket)
    storage.clear_bucket(test_storage_config.storage_bucket)
    yield storage
    logger.debug("Teardown for storage")
    try:
        storage.clear_bucket(test_storage_config.storage_bucket)
    except Exception:
        # Best-effort teardown; previously a bare `except: pass`, which also
        # swallowed SystemExit/KeyboardInterrupt and hid real errors.
        logger.warning("Could not clear bucket during teardown", exc_info=True)

View File

@ -4,43 +4,41 @@ import time
import pytest
import requests
from pyinfra.config import get_config
from pyinfra.payload_processing.monitor import get_monitor
from pyinfra.payload_processing.monitor import PrometheusMonitor
@pytest.fixture(scope="class")
def monitor_config():
config = get_config()
config.prometheus_metric_prefix = "monitor_test"
config.prometheus_port = 8000
return config
@pytest.fixture(scope="class")
def prometheus_monitor(monitor_config):
return get_monitor(monitor_config)
@pytest.fixture
def monitored_mock_function(prometheus_monitor):
def monitored_mock_function(metric_prefix, host, port):
def process(data=None):
time.sleep(2)
return ["result1", "result2", "result3"]
return prometheus_monitor(process)
monitor = PrometheusMonitor(metric_prefix, host, port)
return monitor(process)
@pytest.fixture
def metric_endpoint(host, port):
    """URL of the Prometheus scrape endpoint exposed by the test monitor."""
    return f"http://{host}:{port}/prometheus"
@pytest.mark.parametrize("metric_prefix, host, port", [("test", "0.0.0.0", 8000)], scope="class")
class TestPrometheusMonitor:
def test_prometheus_endpoint_is_available(self, prometheus_monitor, monitor_config):
resp = requests.get(f"http://{monitor_config.prometheus_host}:{monitor_config.prometheus_port}/prometheus")
def test_prometheus_endpoint_is_available(self, metric_endpoint, monitored_mock_function):
resp = requests.get(metric_endpoint)
assert resp.status_code == 200
def test_processing_with_a_monitored_fn_increases_parameter_counter(self, monitored_mock_function, monitor_config):
def test_processing_with_a_monitored_fn_increases_parameter_counter(
self,
metric_endpoint,
metric_prefix,
monitored_mock_function,
):
monitored_mock_function(data=None)
resp = requests.get(f"http://{monitor_config.prometheus_host}:{monitor_config.prometheus_port}/prometheus")
pattern = re.compile(r".*monitor_test_processing_time_count (\d\.\d).*")
resp = requests.get(metric_endpoint)
pattern = re.compile(rf".*{metric_prefix}_processing_time_count (\d\.\d).*")
assert pattern.search(resp.text).group(1) == "1.0"
monitored_mock_function(data=None)
resp = requests.get(f"http://{monitor_config.prometheus_host}:{monitor_config.prometheus_port}/prometheus")
resp = requests.get(metric_endpoint)
assert pattern.search(resp.text).group(1) == "2.0"

View File

@ -0,0 +1,53 @@
import pytest
from pyinfra.payload_processing.payload import (
QueueMessagePayload,
QueueMessagePayloadParser,
)
from pyinfra.utils.file_extension_parsing import make_file_extension_parser
@pytest.fixture
def expected_parsed_payload(x_tenant_id):
    """QueueMessagePayload expected from parsing the standard test payload."""
    return QueueMessagePayload(
        dossier_id="test",
        file_id="test",
        x_tenant_id=x_tenant_id,
        target_file_extension="json.gz",
        response_file_extension="json.gz",
        target_file_type="json",
        target_compression_type="gz",
        response_file_type="json",
        response_compression_type="gz",
        target_file_name="test/test.json.gz",
        response_file_name="test/test.json.gz",
    )
@pytest.fixture
def file_extension_parser(allowed_file_types, allowed_compression_types):
    """Parser fixture restricted to the parametrized file/compression types."""
    return make_file_extension_parser(allowed_file_types, allowed_compression_types)
@pytest.fixture
def payload_parser(file_extension_parser):
    """Queue message payload parser built on the extension parser fixture."""
    return QueueMessagePayloadParser(file_extension_parser)
@pytest.mark.parametrize("allowed_file_types,allowed_compression_types", [(["json", "pdf"], ["gz"])])
class TestPayload:
    """Tests for queue message payload and file extension parsing."""

    @pytest.mark.parametrize("x_tenant_id", [None, "klaus"])
    def test_payload_is_parsed_correctly(self, payload_parser, payload, expected_parsed_payload):
        # Bind the result to a fresh name instead of shadowing the fixture.
        parsed = payload_parser(payload)
        assert parsed == expected_parsed_payload

    @pytest.mark.parametrize(
        "extension,expected",
        [
            ("json.gz", ("json", "gz")),
            ("json", ("json", None)),
            ("prefix.json.gz", ("json", "gz")),
        ],
    )
    def test_parse_file_extension(self, file_extension_parser, extension, expected):
        parsed_extension = file_extension_parser(extension)
        assert parsed_extension == expected

View File

@ -8,14 +8,6 @@ import requests
from pyinfra.payload_processing.processor import make_payload_processor
@pytest.fixture(scope="session")
def file_processor_mock():
def inner(json_file: dict):
return [json_file]
return inner
@pytest.fixture
def target_file():
contents = {"numberOfPages": 10, "content1": "value1", "content2": "value2"}
@ -23,54 +15,60 @@ def target_file():
@pytest.fixture
def file_names(request_payload):
def file_names(payload):
dossier_id, file_id, target_suffix, response_suffix = itemgetter(
"dossierId",
"fileId",
"targetFileExtension",
"responseFileExtension",
)(request_payload)
)(payload)
return f"{dossier_id}/{file_id}.{target_suffix}", f"{dossier_id}/{file_id}.{response_suffix}"
@pytest.fixture(scope="session")
def payload_processor(file_processor_mock, processing_config):
yield make_payload_processor(file_processor_mock, processing_config)
def payload_processor(test_storage_config):
def file_processor_mock(json_file: dict):
return [json_file]
yield make_payload_processor(file_processor_mock, test_storage_config)
@pytest.mark.parametrize("client_name", ["s3"], scope="session")
@pytest.mark.parametrize("storage_backend", ["s3"], scope="session")
@pytest.mark.parametrize("bucket_name", ["testbucket"], scope="session")
@pytest.mark.parametrize("monitoring_enabled", [True, False], scope="session")
@pytest.mark.parametrize("x_tenant_id", [None])
class TestPayloadProcessor:
def test_payload_processor_yields_correct_response_and_uploads_result(
self,
payload_processor,
storage,
bucket_name,
request_payload,
payload,
response_payload,
target_file,
file_names,
):
storage.clear_bucket(bucket_name)
storage.put_object(bucket_name, file_names[0], target_file)
response = payload_processor(request_payload)
response = payload_processor(payload)
assert response == response_payload
data_received = storage.get_object(bucket_name, file_names[1])
assert json.loads((gzip.decompress(data_received)).decode("utf-8")) == {
**request_payload,
**payload,
"data": [json.loads(gzip.decompress(target_file).decode("utf-8"))],
}
def test_catching_of_processing_failure(self, payload_processor, storage, bucket_name, request_payload):
def test_catching_of_processing_failure(self, payload_processor, storage, bucket_name, payload):
storage.clear_bucket(bucket_name)
with pytest.raises(Exception):
payload_processor(request_payload)
payload_processor(payload)
def test_prometheus_endpoint_is_available(self, processing_config):
resp = requests.get(
f"http://{processing_config.prometheus_host}:{processing_config.prometheus_port}/prometheus"
)
assert resp.status_code == 200
def test_prometheus_endpoint_is_available(self, test_storage_config, monitoring_enabled, storage_backend, x_tenant_id):
if monitoring_enabled:
resp = requests.get(
f"http://{test_storage_config.prometheus_host}:{test_storage_config.prometheus_port}/prometheus"
)
assert resp.status_code == 200

View File

@ -1,44 +0,0 @@
import pytest
from pyinfra.config import get_config
from pyinfra.payload_processing.payload import (
QueueMessagePayload,
get_queue_message_payload_parser,
)
from pyinfra.utils.file_extension_parsing import make_file_extension_parser
@pytest.fixture(scope="session")
def payload_config():
return get_config()
class TestPayload:
def test_payload_is_parsed_correctly(self, request_payload, payload_config):
parse_payload = get_queue_message_payload_parser(payload_config)
payload = parse_payload(request_payload)
assert payload == QueueMessagePayload(
dossier_id="test",
file_id="test",
target_file_extension="json.gz",
response_file_extension="json.gz",
target_file_type="json",
target_compression_type="gz",
response_file_type="json",
response_compression_type="gz",
target_file_name="test/test.json.gz",
response_file_name="test/test.json.gz",
)
@pytest.mark.parametrize(
"extension,expected",
[
("json.gz", ("json", "gz")),
("json", ("json", None)),
("prefix.json.gz", ("json", "gz")),
],
)
@pytest.mark.parametrize("allowed_file_types,allowed_compression_types", [(["json", "pdf"], ["gz"])])
def test_parse_file_extension(self, extension, expected, allowed_file_types, allowed_compression_types):
parse = make_file_extension_parser(allowed_file_types, allowed_compression_types)
assert parse(extension) == expected

View File

@ -8,15 +8,16 @@ import pika.exceptions
import pytest
from pyinfra.queue.development_queue_manager import DevelopmentQueueManager
from pyinfra.queue.queue_manager import QueueManager
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)
@pytest.fixture(scope="session")
def development_queue_manager(queue_config):
queue_config.rabbitmq_heartbeat = 7200
development_queue_manager = DevelopmentQueueManager(queue_config)
def development_queue_manager(test_queue_config):
test_queue_config.rabbitmq_heartbeat = 7200
development_queue_manager = DevelopmentQueueManager(test_queue_config)
yield development_queue_manager
logger.info("Tearing down development queue manager...")
try:
@ -26,10 +27,10 @@ def development_queue_manager(queue_config):
@pytest.fixture(scope="session")
def payload_processing_time(queue_config, offset=5):
def payload_processing_time(test_queue_config, offset=5):
# FIXME: this implicitly tests the heartbeat when running the end-to-end test. There should be another way to test
# this explicitly.
return queue_config.rabbitmq_heartbeat + offset
return test_queue_config.rabbitmq_heartbeat + offset
@pytest.fixture(scope="session")
@ -48,10 +49,11 @@ def payload_processor(response_payload, payload_processing_time, payload_process
@pytest.fixture(scope="session", autouse=True)
def start_queue_consumer(queue_manager, payload_processor, sleep_seconds=5):
def start_queue_consumer(test_queue_config, payload_processor, sleep_seconds=5):
def consume_queue():
queue_manager.start_consuming(payload_processor)
queue_manager = QueueManager(test_queue_config)
p = Process(target=consume_queue)
p.start()
logger.info(f"Setting up consumer, waiting for {sleep_seconds}...")
@ -65,39 +67,40 @@ def start_queue_consumer(queue_manager, payload_processor, sleep_seconds=5):
def message_properties(message_headers):
if not message_headers:
return pika.BasicProperties(headers=None)
elif message_headers == "x-tenant-id":
return pika.BasicProperties(headers={"x-tenant-id": "redaction"})
elif message_headers == "X-TENANT-ID":
return pika.BasicProperties(headers={"X-TENANT-ID": "redaction"})
else:
raise Exception(f"Invalid {message_headers=}.")
@pytest.mark.parametrize("x_tenant_id", [None])
class TestQueueManager:
# FIXME: All tests here are wonky. This is due to the implementation of running the process-blocking queue_manager
# in a subprocess. It is then very hard to interact directly with the subprocess. If you have a better idea, please
# refactor; the tests here are insufficient to ensure the functionality of the queue manager!
@pytest.mark.parametrize("payload_processor_type", ["mock"], scope="session")
def test_message_processing_does_not_block_heartbeat(
self, development_queue_manager, request_payload, response_payload, payload_processing_time
self, development_queue_manager, payload, response_payload, payload_processing_time
):
development_queue_manager.clear_queues()
development_queue_manager.publish_request(request_payload)
development_queue_manager.publish_request(payload)
time.sleep(payload_processing_time + 10)
_, _, body = development_queue_manager.get_response()
result = json.loads(body)
assert result == response_payload
@pytest.mark.parametrize("message_headers", [None, "x-tenant-id"])
@pytest.mark.parametrize("message_headers", [None, "X-TENANT-ID"])
@pytest.mark.parametrize("payload_processor_type", ["mock"], scope="session")
def test_queue_manager_forwards_message_headers(
self,
development_queue_manager,
request_payload,
payload,
response_payload,
payload_processing_time,
message_properties,
):
development_queue_manager.clear_queues()
development_queue_manager.publish_request(request_payload, message_properties)
development_queue_manager.publish_request(payload, message_properties)
time.sleep(payload_processing_time + 10)
_, properties, _ = development_queue_manager.get_response()
assert properties.headers == message_properties.headers
@ -109,12 +112,12 @@ class TestQueueManager:
def test_failed_message_processing_is_handled(
self,
development_queue_manager,
request_payload,
payload,
response_payload,
payload_processing_time,
):
development_queue_manager.clear_queues()
development_queue_manager.publish_request(request_payload)
development_queue_manager.publish_request(payload)
time.sleep(payload_processing_time + 10)
_, _, body = development_queue_manager.get_response()
assert not body

View File

@ -6,7 +6,9 @@ logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)
@pytest.mark.parametrize("client_name", ["azure", "s3"], scope="session")
@pytest.mark.parametrize("storage_backend", ["azure", "s3"], scope="session")
@pytest.mark.parametrize("bucket_name", ["testbucket"], scope="session")
@pytest.mark.parametrize("monitoring_enabled", [False], scope="session")
class TestStorage:
def test_clearing_bucket_yields_empty_bucket(self, storage, bucket_name):
storage.clear_bucket(bucket_name)