add multi-tenant storage connection 1st iteration

- forward x-tenant-id from queue message header to
payload processor
- add functions to receive storage infos from an
endpoint or the config. This enables hashing and
caching of connections created from these infos
- add function to initialize storage connections
from storage infos
- streamline and refactor tests to make them more
readable and robust and to make it easier to add
 new tests
- update payload processor with first iteration
of multi-tenancy storage connection support
with connection caching and backward compatibility
This commit is contained in:
Julius Unverfehrt 2023-03-22 13:34:43 +01:00
parent 52c047c47b
commit d48e8108fd
21 changed files with 353 additions and 390 deletions

View File

@ -28,11 +28,9 @@ class Config:
"PROMETHEUS_METRIC_PREFIX", "redactmanager_research_service_parameter" "PROMETHEUS_METRIC_PREFIX", "redactmanager_research_service_parameter"
) )
# Prometheus webserver address # Prometheus webserver address and port
self.prometheus_host = read_from_environment("PROMETHEUS_HOST", "127.0.0.1") self.prometheus_host = "0.0.0.0"
self.prometheus_port = 8080
# Prometheus webserver port
self.prometheus_port = int(read_from_environment("PROMETHEUS_PORT", 8080))
# RabbitMQ host address # RabbitMQ host address
self.rabbitmq_host = read_from_environment("RABBITMQ_HOST", "localhost") self.rabbitmq_host = read_from_environment("RABBITMQ_HOST", "localhost")
@ -94,6 +92,10 @@ class Config:
self.allowed_file_types = ["json", "pdf"] self.allowed_file_types = ["json", "pdf"]
self.allowed_compression_types = ["gz"] self.allowed_compression_types = ["gz"]
# config for x-tenant-endpoint to receive storage connection information per tenant
self.persistence_service_public_key = "redaction"
self.persistence_service_tenant_endpoint = "http://persistence-service-v1:8080/internal-api/tenants"
# Value to see if we should write a consumer token to a file # Value to see if we should write a consumer token to a file
self.write_consumer_token = read_from_environment("WRITE_CONSUMER_TOKEN", "False") self.write_consumer_token = read_from_environment("WRITE_CONSUMER_TOKEN", "False")

View File

@ -1,2 +1,5 @@
class ProcessingFailure(RuntimeError): class ProcessingFailure(RuntimeError):
pass pass
class UnknownStorageBackend(Exception):
    """Raised when no supported storage backend can be determined.

    Used by the storage factories when a storage info object, a tenant endpoint
    response, or the configured ``storage_backend`` does not map to a known backend.
    """

View File

@ -12,7 +12,7 @@ logger = logging.getLogger()
class PrometheusMonitor: class PrometheusMonitor:
def __init__(self, prefix: str, port=8080, host="127.0.0.1"): def __init__(self, prefix: str, host: str, port: int):
"""Register the monitoring metrics and start a webserver where they can be scraped at the endpoint """Register the monitoring metrics and start a webserver where they can be scraped at the endpoint
http://{host}:{port}/prometheus http://{host}:{port}/prometheus
@ -23,12 +23,9 @@ class PrometheusMonitor:
self.registry = CollectorRegistry() self.registry = CollectorRegistry()
self.entity_processing_time_sum = Summary( self.entity_processing_time_sum = Summary(
f"{prefix}_processing_time", f"{prefix}_processing_time", "Summed up average processing time per entity observed", registry=self.registry
"Summed up average processing time per entity observed",
) )
self.registry.register(self.entity_processing_time_sum)
start_http_server(port, host, self.registry) start_http_server(port, host, self.registry)
def __call__(self, process_fn: Callable) -> Callable: def __call__(self, process_fn: Callable) -> Callable:
@ -58,8 +55,8 @@ class PrometheusMonitor:
return inner return inner
def get_monitor(config: Config) -> Callable: def get_monitor_from_config(config: Config) -> Callable:
if config.monitoring_enabled: if config.monitoring_enabled:
return PrometheusMonitor(*attrgetter("prometheus_metric_prefix", "prometheus_port", "prometheus_host")(config)) return PrometheusMonitor(*attrgetter("prometheus_metric_prefix", "prometheus_host", "prometheus_port")(config))
else: else:
return identity return identity

View File

@ -11,6 +11,8 @@ from pyinfra.utils.file_extension_parsing import make_file_extension_parser
class QueueMessagePayload: class QueueMessagePayload:
dossier_id: str dossier_id: str
file_id: str file_id: str
x_tenant_id: Union[str, None]
target_file_extension: str target_file_extension: str
response_file_extension: str response_file_extension: str
@ -35,6 +37,7 @@ class QueueMessagePayloadParser:
dossier_id, file_id, target_file_extension, response_file_extension = itemgetter( dossier_id, file_id, target_file_extension, response_file_extension = itemgetter(
"dossierId", "fileId", "targetFileExtension", "responseFileExtension" "dossierId", "fileId", "targetFileExtension", "responseFileExtension"
)(payload) )(payload)
x_tenant_id = payload.get("X-TENANT-ID")
target_file_type, target_compression_type, response_file_type, response_compression_type = chain.from_iterable( target_file_type, target_compression_type, response_file_type, response_compression_type = chain.from_iterable(
map(self.parse_file_extensions, [target_file_extension, response_file_extension]) map(self.parse_file_extensions, [target_file_extension, response_file_extension])
@ -46,6 +49,7 @@ class QueueMessagePayloadParser:
return QueueMessagePayload( return QueueMessagePayload(
dossier_id=dossier_id, dossier_id=dossier_id,
file_id=file_id, file_id=file_id,
x_tenant_id=x_tenant_id,
target_file_extension=target_file_extension, target_file_extension=target_file_extension,
response_file_extension=response_file_extension, response_file_extension=response_file_extension,
target_file_type=target_file_type, target_file_type=target_file_type,

View File

@ -6,16 +6,20 @@ from typing import Callable, Union, List
from funcy import compose from funcy import compose
from pyinfra.config import get_config, Config from pyinfra.config import get_config, Config
from pyinfra.payload_processing.monitor import get_monitor from pyinfra.payload_processing.monitor import get_monitor_from_config
from pyinfra.payload_processing.payload import ( from pyinfra.payload_processing.payload import (
QueueMessagePayloadParser, QueueMessagePayloadParser,
get_queue_message_payload_parser, get_queue_message_payload_parser,
QueueMessagePayloadFormatter, QueueMessagePayloadFormatter,
get_queue_message_payload_formatter, get_queue_message_payload_formatter,
) )
from pyinfra.storage import get_storage from pyinfra.storage.storage import make_downloader, make_uploader, get_storage_from_storage_info
from pyinfra.storage.storage import make_downloader, make_uploader from pyinfra.storage.storage_info import (
from pyinfra.storage.storages.interface import Storage AzureStorageInfo,
S3StorageInfo,
get_storage_info_from_config,
get_storage_info_from_endpoint,
)
logger = logging.getLogger() logger = logging.getLogger()
logger.setLevel(get_config().logging_level_root) logger.setLevel(get_config().logging_level_root)
@ -24,8 +28,8 @@ logger.setLevel(get_config().logging_level_root)
class PayloadProcessor: class PayloadProcessor:
def __init__( def __init__(
self, self,
storage: Storage, default_storage_info: Union[AzureStorageInfo, S3StorageInfo],
bucket: str, get_storage_info_from_tenant_id,
payload_parser: QueueMessagePayloadParser, payload_parser: QueueMessagePayloadParser,
payload_formatter: QueueMessagePayloadFormatter, payload_formatter: QueueMessagePayloadFormatter,
data_processor: Callable, data_processor: Callable,
@ -33,8 +37,9 @@ class PayloadProcessor:
"""Wraps an analysis function specified by a service (e.g. NER service) in pre- and post-processing steps. """Wraps an analysis function specified by a service (e.g. NER service) in pre- and post-processing steps.
Args: Args:
storage: The storage to use for downloading and uploading files default_storage_info: The default storage info used to create the storage connection. This is only used if
bucket: The bucket to use for downloading and uploading files x_tenant_id is not provided in the queue payload.
get_storage_info_from_tenant_id: Callable to acquire storage info from a given tenant id.
payload_parser: Parser that translates the queue message payload to the required QueueMessagePayload object payload_parser: Parser that translates the queue message payload to the required QueueMessagePayload object
payload_formatter: Formatter for the storage upload result and the queue message response body payload_formatter: Formatter for the storage upload result and the queue message response body
data_processor: The analysis function to be called with the downloaded file data_processor: The analysis function to be called with the downloaded file
@ -46,8 +51,10 @@ class PayloadProcessor:
self.format_to_queue_message_response_body = payload_formatter.format_to_queue_message_response_body self.format_to_queue_message_response_body = payload_formatter.format_to_queue_message_response_body
self.process_data = data_processor self.process_data = data_processor
self.make_downloader = partial(make_downloader, storage, bucket) self.get_storage_info_from_tenant_id = get_storage_info_from_tenant_id
self.make_uploader = partial(make_uploader, storage, bucket) self.default_storage_info = default_storage_info
# TODO: use lru-dict
self.storages = {}
def __call__(self, queue_message_payload: dict) -> dict: def __call__(self, queue_message_payload: dict) -> dict:
"""Processes a queue message payload. """Processes a queue message payload.
@ -71,8 +78,16 @@ class PayloadProcessor:
payload = self.parse_payload(queue_message_payload) payload = self.parse_payload(queue_message_payload)
logger.info(f"Processing {asdict(payload)} ...") logger.info(f"Processing {asdict(payload)} ...")
download_file_to_process = self.make_downloader(payload.target_file_type, payload.target_compression_type) storage_info = self._get_storage_info(payload.x_tenant_id)
upload_processing_result = self.make_uploader(payload.response_file_type, payload.response_compression_type) bucket = storage_info.bucket_name
storage = self._get_storage(storage_info)
download_file_to_process = make_downloader(
storage, bucket, payload.target_file_type, payload.target_compression_type
)
upload_processing_result = make_uploader(
storage, bucket, payload.response_file_type, payload.response_compression_type
)
format_result_for_storage = partial(self.format_result_for_storage, payload) format_result_for_storage = partial(self.format_result_for_storage, payload)
processing_pipeline = compose(format_result_for_storage, self.process_data, download_file_to_process) processing_pipeline = compose(format_result_for_storage, self.process_data, download_file_to_process)
@ -83,17 +98,40 @@ class PayloadProcessor:
return self.format_to_queue_message_response_body(payload) return self.format_to_queue_message_response_body(payload)
def _get_storage_info(self, x_tenant_id=None):
    """Return the storage info to use for the given tenant.

    Args:
        x_tenant_id: Tenant id taken from the queue message payload, or None.

    Returns:
        The result of ``self.get_storage_info_from_tenant_id(x_tenant_id)`` when a truthy
        tenant id is given; otherwise ``self.default_storage_info`` (this keeps messages
        without a tenant id working as before).
    """
    if x_tenant_id:
        return self.get_storage_info_from_tenant_id(x_tenant_id)
    return self.default_storage_info
def _get_storage(self, storage_info):
    """Return the storage connection for ``storage_info``, creating it on first use.

    Connections are cached in ``self.storages`` keyed by the (hashable) storage info, so
    repeated messages for the same storage reuse one connection.

    NOTE: the cache is a plain dict and is never evicted; it grows with the number of
    distinct storage infos seen.
    """
    if storage_info not in self.storages:
        self.storages[storage_info] = get_storage_from_storage_info(storage_info)
    return self.storages[storage_info]
def make_payload_processor(data_processor: Callable, config: Union[None, Config] = None) -> PayloadProcessor: def make_payload_processor(data_processor: Callable, config: Union[None, Config] = None) -> PayloadProcessor:
"""Produces payload processor for queue manager.""" """Produces payload processor for queue manager."""
config = config or get_config() config = config or get_config()
bucket: str = config.storage_bucket default_storage_info: Union[AzureStorageInfo, S3StorageInfo] = get_storage_info_from_config(config)
storage: Storage = get_storage(config) get_storage_info_from_tenant_id = partial(
monitor = get_monitor(config) get_storage_info_from_endpoint,
config.persistence_service_public_key,
config.persistence_service_tenant_endpoint,
)
monitor = get_monitor_from_config(config)
payload_parser: QueueMessagePayloadParser = get_queue_message_payload_parser(config) payload_parser: QueueMessagePayloadParser = get_queue_message_payload_parser(config)
payload_formatter: QueueMessagePayloadFormatter = get_queue_message_payload_formatter() payload_formatter: QueueMessagePayloadFormatter = get_queue_message_payload_formatter()
data_processor = monitor(data_processor) data_processor = monitor(data_processor)
return PayloadProcessor(storage, bucket, payload_parser, payload_formatter, data_processor) return PayloadProcessor(
default_storage_info,
get_storage_info_from_tenant_id,
payload_parser,
payload_formatter,
data_processor,
)

View File

@ -12,6 +12,7 @@ from pika.adapters.blocking_connection import BlockingChannel
from pyinfra.config import Config from pyinfra.config import Config
from pyinfra.exception import ProcessingFailure from pyinfra.exception import ProcessingFailure
from pyinfra.payload_processing.processor import PayloadProcessor from pyinfra.payload_processing.processor import PayloadProcessor
from pyinfra.utils.dict import save_project
CONFIG = Config() CONFIG = Config()
@ -164,8 +165,8 @@ class QueueManager:
except Exception as err: except Exception as err:
raise ProcessingFailure("QueueMessagePayload processing failed") from err raise ProcessingFailure("QueueMessagePayload processing failed") from err
def acknowledge_message_and_publish_response(frame, properties, response_body): def acknowledge_message_and_publish_response(frame, headers, response_body):
response_properties = pika.BasicProperties(headers=properties.headers) if properties.headers else None response_properties = pika.BasicProperties(headers=headers) if headers else None
self._channel.basic_publish("", self._output_queue, json.dumps(response_body).encode(), response_properties) self._channel.basic_publish("", self._output_queue, json.dumps(response_body).encode(), response_properties)
self.logger.info( self.logger.info(
"Result published, acknowledging incoming message with delivery_tag %s", "Result published, acknowledging incoming message with delivery_tag %s",
@ -190,12 +191,15 @@ class QueueManager:
try: try:
self.logger.debug("Processing (%s, %s, %s)", frame, properties, body) self.logger.debug("Processing (%s, %s, %s)", frame, properties, body)
processing_result = process_message_body_and_await_result(json.loads(body)) filtered_message_headers = save_project(properties.headers, ["X-TENANT-ID"]) # TODO: parametrize key?
message_body = {**json.loads(body), **filtered_message_headers}
processing_result = process_message_body_and_await_result(message_body)
self.logger.info( self.logger.info(
"Processed message with delivery_tag %s, publishing result to result-queue", "Processed message with delivery_tag %s, publishing result to result-queue",
frame.delivery_tag, frame.delivery_tag,
) )
acknowledge_message_and_publish_response(frame, properties, processing_result) acknowledge_message_and_publish_response(frame, filtered_message_headers, processing_result)
except ProcessingFailure: except ProcessingFailure:
self.logger.info( self.logger.info(

View File

@ -1,3 +1,3 @@
from pyinfra.storage.storage import get_storage from pyinfra.storage.storage import get_storage_from_config
__all__ = ["get_storage"] __all__ = ["get_storage_from_config"]

View File

@ -1,28 +1,45 @@
from functools import lru_cache, partial from functools import lru_cache, partial
from typing import Callable from typing import Callable, Union
from azure.storage.blob import BlobServiceClient
from funcy import compose from funcy import compose
from minio import Minio
from pyinfra.config import Config from pyinfra.config import Config
from pyinfra.storage.storages.azure import get_azure_storage from pyinfra.exception import UnknownStorageBackend
from pyinfra.storage.storages.s3 import get_s3_storage from pyinfra.storage.storage_info import AzureStorageInfo, S3StorageInfo, get_storage_info_from_config
from pyinfra.storage.storages.azure import AzureStorage
from pyinfra.storage.storages.interface import Storage from pyinfra.storage.storages.interface import Storage
from pyinfra.storage.storages.s3 import S3Storage
from pyinfra.utils.compressing import get_decompressor, get_compressor from pyinfra.utils.compressing import get_decompressor, get_compressor
from pyinfra.utils.encoding import get_decoder, get_encoder from pyinfra.utils.encoding import get_decoder, get_encoder
def get_storage(config: Config) -> Storage: def get_storage_from_config(config: Config) -> Storage:
if config.storage_backend == "s3": storage_info = get_storage_info_from_config(config)
storage = get_s3_storage(config) storage = get_storage_from_storage_info(storage_info)
elif config.storage_backend == "azure":
storage = get_azure_storage(config)
else:
raise Exception(f"Unknown storage backend '{config.storage_backend}'.")
return storage return storage
def get_storage_from_storage_info(storage_info: Union[AzureStorageInfo, S3StorageInfo]) -> Storage:
    """Create a storage connection from a storage info object.

    Args:
        storage_info: Connection parameters for either an Azure or an S3 storage.

    Returns:
        An ``AzureStorage`` wrapping a ``BlobServiceClient`` built from the connection string,
        or an ``S3Storage`` wrapping a ``Minio`` client built from the S3 parameters.

    Raises:
        UnknownStorageBackend: If ``storage_info`` is neither an ``AzureStorageInfo`` nor an
            ``S3StorageInfo``.
    """
    if isinstance(storage_info, AzureStorageInfo):
        return AzureStorage(BlobServiceClient.from_connection_string(conn_str=storage_info.connection_string))

    if isinstance(storage_info, S3StorageInfo):
        client = Minio(
            secure=storage_info.secure,
            endpoint=storage_info.endpoint,
            access_key=storage_info.access_key,
            secret_key=storage_info.secret_key,
            region=storage_info.region,
        )
        return S3Storage(client)

    # Name the offending type so a misconfiguration is diagnosable from the log alone.
    raise UnknownStorageBackend(f"Unknown storage info type '{type(storage_info).__name__}'.")
def verify_existence(storage: Storage, bucket: str, file_name: str) -> str: def verify_existence(storage: Storage, bucket: str, file_name: str) -> str:
if not storage.exists(bucket, file_name): if not storage.exists(bucket, file_name):
raise FileNotFoundError(f"{file_name=} name not found on storage in {bucket=}.") raise FileNotFoundError(f"{file_name=} name not found on storage in {bucket=}.")

View File

@ -0,0 +1,89 @@
from dataclasses import dataclass
from typing import Union
import requests
from pyinfra.config import Config
from pyinfra.exception import UnknownStorageBackend
from pyinfra.utils.cipher import decrypt
from pyinfra.utils.url_parsing import validate_and_parse_s3_endpoint
@dataclass(frozen=True)
class AzureStorageInfo:
    """Immutable connection parameters for an Azure blob storage.

    Instances are hashable so they can be used as cache keys for storage connections.
    """

    # Azure connection string used to build the blob service client.
    connection_string: str
    # Name of the container ("bucket") to read from and write to.
    bucket_name: str

    def __hash__(self):
        # Only the connection string contributes to the hash. The generated __eq__ still
        # compares bucket_name as well, so infos that differ only in the bucket share a hash
        # but are not equal.
        return hash(self.connection_string)
@dataclass(frozen=True)
class S3StorageInfo:
    """Immutable connection parameters for an S3-compatible storage.

    Instances are hashable so they can be used as cache keys for storage connections.
    """

    # Whether to connect via TLS.
    secure: bool
    # Host (and optional port) of the S3 endpoint.
    endpoint: str
    access_key: str
    secret_key: str
    region: str
    # Name of the bucket to read from and write to.
    bucket_name: str

    def __hash__(self):
        # Only the connection parameters contribute to the hash. The generated __eq__ still
        # compares bucket_name as well, so infos that differ only in the bucket share a hash
        # but are not equal.
        return hash((self.secure, self.endpoint, self.access_key, self.secret_key, self.region))
def get_storage_info_from_endpoint(
    public_key: str, endpoint: str, x_tenant_id: str, timeout: float = 10.0
) -> Union[AzureStorageInfo, S3StorageInfo]:
    """Fetch the storage connection information of a tenant from the tenant endpoint.

    Args:
        public_key: Key passed to ``decrypt`` to decode the encrypted connection secrets.
        endpoint: Base URL of the tenant endpoint; the tenant id is appended as a path segment.
        x_tenant_id: Id of the tenant whose storage information is requested.
        timeout: Seconds to wait for the endpoint before giving up.

    Returns:
        An ``AzureStorageInfo`` or ``S3StorageInfo`` built from the endpoint response.

    Raises:
        requests.HTTPError: If the endpoint answers with an error status.
        UnknownStorageBackend: If the response contains no storage connection, or both an
            Azure and an S3 connection (ambiguous).
    """
    response = requests.get(f"{endpoint}/{x_tenant_id}", timeout=timeout)
    response.raise_for_status()
    tenant = response.json()

    maybe_azure = tenant.get("azureStorageConnection")
    # NOTE(review): the S3 lookup previously repeated "azureStorageConnection" (copy-paste), which
    # made the S3 branch unreachable. "s3StorageConnection" and "bucketName" below mirror the naming
    # of the Azure fields — confirm against the tenant endpoint's response schema.
    maybe_s3 = tenant.get("s3StorageConnection")

    # Explicit check instead of `assert`, which is stripped when Python runs with -O.
    if maybe_azure and maybe_s3:
        raise UnknownStorageBackend(
            f"Ambiguous storage backend for tenant '{x_tenant_id}': both Azure and S3 connections are configured."
        )

    if maybe_azure:
        return AzureStorageInfo(
            connection_string=decrypt(public_key, maybe_azure["connectionString"]),
            bucket_name=maybe_azure["containerName"],
        )

    if maybe_s3:
        secure, s3_endpoint = validate_and_parse_s3_endpoint(maybe_s3["endpoint"])
        return S3StorageInfo(
            secure=secure,
            endpoint=s3_endpoint,
            access_key=maybe_s3["key"],
            secret_key=decrypt(public_key, maybe_s3["secret"]),
            region=maybe_s3["region"],
            bucket_name=maybe_s3["bucketName"],
        )

    raise UnknownStorageBackend(f"No storage connection configured for tenant '{x_tenant_id}'.")
def get_storage_info_from_config(config: Config) -> Union[AzureStorageInfo, S3StorageInfo]:
    """Build the storage info for the backend selected in the config.

    Args:
        config: Service configuration; ``config.storage_backend`` selects the backend
            ("s3" or "azure") and the ``storage_*`` attributes supply its parameters.

    Returns:
        An ``S3StorageInfo`` or ``AzureStorageInfo`` populated from the config.

    Raises:
        UnknownStorageBackend: If ``config.storage_backend`` is neither "s3" nor "azure".
    """
    backend = config.storage_backend

    if backend == "s3":
        return S3StorageInfo(
            secure=config.storage_secure_connection,
            endpoint=config.storage_endpoint,
            access_key=config.storage_key,
            secret_key=config.storage_secret,
            region=config.storage_region,
            bucket_name=config.storage_bucket,
        )

    if backend == "azure":
        return AzureStorageInfo(
            connection_string=config.storage_azureconnectionstring,
            bucket_name=config.storage_bucket,
        )

    raise UnknownStorageBackend(f"Unknown storage backend '{config.storage_backend}'.")

View File

@ -77,5 +77,5 @@ class AzureStorage(Storage):
return zip(repeat(bucket_name), map(attrgetter("name"), blobs)) return zip(repeat(bucket_name), map(attrgetter("name"), blobs))
def get_azure_storage(config: Config): def get_azure_storage_from_config(config: Config):
return AzureStorage(BlobServiceClient.from_connection_string(conn_str=config.storage_azureconnectionstring)) return AzureStorage(BlobServiceClient.from_connection_string(conn_str=config.storage_azureconnectionstring))

View File

@ -67,7 +67,7 @@ class S3Storage(Storage):
return zip(repeat(bucket_name), map(attrgetter("object_name"), objs)) return zip(repeat(bucket_name), map(attrgetter("object_name"), objs))
def get_s3_storage(config: Config): def get_s3_storage_from_config(config: Config):
return S3Storage( return S3Storage(
Minio( Minio(
secure=config.storage_secure_connection, secure=config.storage_secure_connection,

5
pyinfra/utils/dict.py Normal file
View File

@ -0,0 +1,5 @@
from funcy import project
def save_project(mapping, keys) -> dict:
    """Project ``mapping`` onto ``keys``, tolerating a missing mapping.

    Args:
        mapping: The mapping to select from; may be None or empty.
        keys: The keys to keep.

    Returns:
        ``funcy.project(mapping, keys)`` when ``mapping`` is truthy, otherwise an empty dict.
    """
    return project(mapping, keys) if mapping else {}

View File

@ -7,7 +7,7 @@ import pika
from pyinfra.config import get_config from pyinfra.config import get_config
from pyinfra.queue.development_queue_manager import DevelopmentQueueManager from pyinfra.queue.development_queue_manager import DevelopmentQueueManager
from pyinfra.storage.storages.s3 import get_s3_storage from pyinfra.storage.storages.s3 import get_s3_storage_from_config
CONFIG = get_config() CONFIG = get_config()
logging.basicConfig() logging.basicConfig()
@ -26,7 +26,7 @@ def upload_json_and_make_message_body():
object_name = f"{dossier_id}/{file_id}.{suffix}" object_name = f"{dossier_id}/{file_id}.{suffix}"
data = gzip.compress(json.dumps(content).encode("utf-8")) data = gzip.compress(json.dumps(content).encode("utf-8"))
storage = get_s3_storage(CONFIG) storage = get_s3_storage_from_config(CONFIG)
if not storage.has_bucket(bucket): if not storage.has_bucket(bucket):
storage.make_bucket(bucket) storage.make_bucket(bucket)
storage.put_object(bucket, object_name, data) storage.put_object(bucket, object_name, data)
@ -46,10 +46,10 @@ def main():
message = upload_json_and_make_message_body() message = upload_json_and_make_message_body()
development_queue_manager.publish_request(message, pika.BasicProperties(headers={"x-tenant-id": "redaction"})) development_queue_manager.publish_request(message, pika.BasicProperties(headers={"X-TENANT-ID": "redaction"}))
logger.info(f"Put {message} on {CONFIG.request_queue}") logger.info(f"Put {message} on {CONFIG.request_queue}")
storage = get_s3_storage(CONFIG) storage = get_s3_storage_from_config(CONFIG)
for method_frame, properties, body in development_queue_manager._channel.consume( for method_frame, properties, body in development_queue_manager._channel.consume(
queue=CONFIG.response_queue, inactivity_timeout=15 queue=CONFIG.response_queue, inactivity_timeout=15
): ):

View File

@ -1,194 +0,0 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"ename": "ModuleNotFoundError",
"evalue": "No module named 'pprint.pprint'; 'pprint' is not a package",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mModuleNotFoundError\u001b[0m Traceback (most recent call last)",
"Cell \u001b[0;32mIn [10], line 4\u001b[0m\n\u001b[1;32m 2\u001b[0m \u001b[38;5;28;01mimport\u001b[39;00m \u001b[38;5;21;01myaml\u001b[39;00m\n\u001b[1;32m 3\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01myaml\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mloader\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m FullLoader\n\u001b[0;32m----> 4\u001b[0m \u001b[38;5;28;01mimport\u001b[39;00m \u001b[38;5;21;01mpprint\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mpprint\u001b[39;00m \u001b[38;5;28;01mas\u001b[39;00m \u001b[38;5;21;01mpp\u001b[39;00m\n",
"\u001b[0;31mModuleNotFoundError\u001b[0m: No module named 'pprint.pprint'; 'pprint' is not a package"
]
}
],
"source": [
"import pyinfra\n",
"import yaml\n",
"from yaml.loader import FullLoader\n",
"import pprint"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"{'logging': 0,\n",
" 'mock_analysis_endpoint': 'http://127.0.0.1:5000',\n",
" 'service': {'operations': {'classify': {'input': {'extension': 'cls_in.gz',\n",
" 'multi': True,\n",
" 'subdir': ''},\n",
" 'output': {'extension': 'cls_out.gz',\n",
" 'subdir': ''}},\n",
" 'default': {'input': {'extension': 'IN.gz',\n",
" 'multi': False,\n",
" 'subdir': ''},\n",
" 'output': {'extension': 'OUT.gz',\n",
" 'subdir': ''}},\n",
" 'extract': {'input': {'extension': 'extr_in.gz',\n",
" 'multi': False,\n",
" 'subdir': ''},\n",
" 'output': {'extension': 'gz',\n",
" 'subdir': 'extractions'}},\n",
" 'rotate': {'input': {'extension': 'rot_in.gz',\n",
" 'multi': False,\n",
" 'subdir': ''},\n",
" 'output': {'extension': 'rot_out.gz',\n",
" 'subdir': ''}},\n",
" 'stream_pages': {'input': {'extension': 'pgs_in.gz',\n",
" 'multi': False,\n",
" 'subdir': ''},\n",
" 'output': {'extension': 'pgs_out.gz',\n",
" 'subdir': 'pages'}},\n",
" 'upper': {'input': {'extension': 'up_in.gz',\n",
" 'multi': False,\n",
" 'subdir': ''},\n",
" 'output': {'extension': 'up_out.gz',\n",
" 'subdir': ''}}},\n",
" 'response_formatter': 'identity'},\n",
" 'storage': {'aws': {'access_key': 'AKIA4QVP6D4LCDAGYGN2',\n",
" 'endpoint': 'https://s3.amazonaws.com',\n",
" 'region': '$STORAGE_REGION|\"eu-west-1\"',\n",
" 'secret_key': '8N6H1TUHTsbvW2qMAm7zZlJ63hMqjcXAsdN7TYED'},\n",
" 'azure': {'connection_string': 'DefaultEndpointsProtocol=https;AccountName=iqserdevelopment;AccountKey=4imAbV9PYXaztSOMpIyAClg88bAZCXuXMGJG0GA1eIBpdh2PlnFGoRBnKqLy2YZUSTmZ3wJfC7tzfHtuC6FEhQ==;EndpointSuffix=core.windows.net'},\n",
" 'bucket': 'pyinfra-test-bucket',\n",
" 'minio': {'access_key': 'root',\n",
" 'endpoint': 'http://127.0.0.1:9000',\n",
" 'region': None,\n",
" 'secret_key': 'password'}},\n",
" 'use_docker_fixture': 1,\n",
" 'webserver': {'host': '$SERVER_HOST|\"127.0.0.1\"',\n",
" 'mode': '$SERVER_MODE|production',\n",
" 'port': '$SERVER_PORT|5000'}}\n"
]
}
],
"source": [
"\n",
"# Open the file and load the file\n",
"with open('./tests/config.yml') as f:\n",
" data = yaml.load(f, Loader=FullLoader)\n",
" pprint.pprint(data)"
]
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"os.environ[\"STORAGE_BACKEND\"] = \"azure\"\n",
"\n",
"# always the same\n",
"os.environ[\"STORAGE_BUCKET_NAME\"] = \"pyinfra-test-bucket\"\n",
"\n",
"# s3\n",
"os.environ[\"STORAGE_ENDPOINT\"] = \"https://s3.amazonaws.com\"\n",
"os.environ[\"STORAGE_KEY\"] = \"AKIA4QVP6D4LCDAGYGN2\"\n",
"os.environ[\"STORAGE_SECRET\"] = \"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\"\n",
"os.environ[\"STORAGE_REGION\"] = \"eu-west-1\"\n",
"\n",
"# aks\n",
"os.environ[\"STORAGE_AZURECONNECTIONSTRING\"] = \"DefaultEndpointsProtocol=https;AccountName=iqserdevelopment;AccountKey=4imAbV9PYXaztSOMpIyAClg88bAZCXuXMGJG0GA1eIBpdh2PlnFGoRBnKqLy2YZUSTmZ3wJfC7tzfHtuC6FEhQ==;EndpointSuffix=core.windows.net\""
]
},
{
"cell_type": "code",
"execution_count": 23,
"metadata": {},
"outputs": [
{
"ename": "Exception",
"evalue": "Unknown storage backend 'aks'.",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mException\u001b[0m Traceback (most recent call last)",
"Cell \u001b[0;32mIn [23], line 2\u001b[0m\n\u001b[1;32m 1\u001b[0m config \u001b[38;5;241m=\u001b[39m pyinfra\u001b[38;5;241m.\u001b[39mconfig\u001b[38;5;241m.\u001b[39mget_config()\n\u001b[0;32m----> 2\u001b[0m storage \u001b[38;5;241m=\u001b[39m \u001b[43mpyinfra\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mstorage\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mget_storage\u001b[49m\u001b[43m(\u001b[49m\u001b[43mconfig\u001b[49m\u001b[43m)\u001b[49m\n",
"File \u001b[0;32m~/dev/pyinfra/pyinfra/storage/storage.py:15\u001b[0m, in \u001b[0;36mget_storage\u001b[0;34m(config)\u001b[0m\n\u001b[1;32m 13\u001b[0m storage \u001b[39m=\u001b[39m get_azure_storage(config)\n\u001b[1;32m 14\u001b[0m \u001b[39melse\u001b[39;00m:\n\u001b[0;32m---> 15\u001b[0m \u001b[39mraise\u001b[39;00m \u001b[39mException\u001b[39;00m(\u001b[39mf\u001b[39m\u001b[39m\"\u001b[39m\u001b[39mUnknown storage backend \u001b[39m\u001b[39m'\u001b[39m\u001b[39m{\u001b[39;00mconfig\u001b[39m.\u001b[39mstorage_backend\u001b[39m}\u001b[39;00m\u001b[39m'\u001b[39m\u001b[39m.\u001b[39m\u001b[39m\"\u001b[39m)\n\u001b[1;32m 17\u001b[0m \u001b[39mreturn\u001b[39;00m storage\n",
"\u001b[0;31mException\u001b[0m: Unknown storage backend 'aks'."
]
}
],
"source": [
"config = pyinfra.config.get_config()\n",
"storage = pyinfra.storage.get_storage(config)"
]
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"False"
]
},
"execution_count": 21,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"storage.has_bucket(config.storage_bucket)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3.8.13 ('pyinfra-TboPpZ8z-py3.8')",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.13"
},
"orig_nbformat": 4,
"vscode": {
"interpreter": {
"hash": "10d7419af5ea6dfec0078ebc9d6fa1a9383fe9894853f90dc7d29a81b3de2c78"
}
}
},
"nbformat": 4,
"nbformat_minor": 2
}

View File

@ -6,8 +6,7 @@ import pytest
import testcontainers.compose import testcontainers.compose
from pyinfra.config import get_config from pyinfra.config import get_config
from pyinfra.queue.queue_manager import QueueManager from pyinfra.storage import get_storage_from_config
from pyinfra.storage import get_storage
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG) logger.setLevel(logging.DEBUG)
@ -30,60 +29,35 @@ def docker_compose(sleep_seconds=30):
@pytest.fixture(scope="session") @pytest.fixture(scope="session")
def storage_config(client_name): def test_storage_config(storage_backend, bucket_name, monitoring_enabled):
config = get_config() config = get_config()
config.storage_backend = client_name config.storage_backend = storage_backend
config.storage_bucket = bucket_name
config.storage_azureconnectionstring = "DefaultEndpointsProtocol=https;AccountName=iqserdevelopment;AccountKey=4imAbV9PYXaztSOMpIyAClg88bAZCXuXMGJG0GA1eIBpdh2PlnFGoRBnKqLy2YZUSTmZ3wJfC7tzfHtuC6FEhQ==;EndpointSuffix=core.windows.net" config.storage_azureconnectionstring = "DefaultEndpointsProtocol=https;AccountName=iqserdevelopment;AccountKey=4imAbV9PYXaztSOMpIyAClg88bAZCXuXMGJG0GA1eIBpdh2PlnFGoRBnKqLy2YZUSTmZ3wJfC7tzfHtuC6FEhQ==;EndpointSuffix=core.windows.net"
config.monitoring_enabled = monitoring_enabled
config.prometheus_metric_prefix = "test"
config.prometheus_port = 8080
config.prometheus_host = "0.0.0.0"
return config return config
@pytest.fixture(scope="session") @pytest.fixture(scope="session")
def processing_config(storage_config, monitoring_enabled): def test_queue_config():
storage_config.monitoring_enabled = monitoring_enabled
return storage_config
@pytest.fixture(scope="session")
def bucket_name(storage_config):
return storage_config.storage_bucket
@pytest.fixture(scope="session")
def storage(storage_config):
logger.debug("Setup for storage")
storage = get_storage(storage_config)
storage.make_bucket(storage_config.storage_bucket)
storage.clear_bucket(storage_config.storage_bucket)
yield storage
logger.debug("Teardown for storage")
try:
storage.clear_bucket(storage_config.storage_bucket)
except:
pass
@pytest.fixture(scope="session")
def queue_config(payload_processor_type):
config = get_config() config = get_config()
# FIXME: It looks like rabbitmq_heartbeat has to be greater than rabbitmq_connection_sleep. If this is expected, the config.rabbitmq_connection_sleep = 2
# user should not be able to insert non working values. config.rabbitmq_heartbeat = 4
config.rabbitmq_heartbeat = config.rabbitmq_connection_sleep + 1
return config return config
@pytest.fixture(scope="session")
def queue_manager(queue_config):
queue_manager = QueueManager(queue_config)
return queue_manager
@pytest.fixture @pytest.fixture
def request_payload(): def payload(x_tenant_id):
x_tenant_entry = {"X-TENANT-ID": x_tenant_id} if x_tenant_id else {}
return { return {
"dossierId": "test", "dossierId": "test",
"fileId": "test", "fileId": "test",
"targetFileExtension": "json.gz", "targetFileExtension": "json.gz",
"responseFileExtension": "json.gz", "responseFileExtension": "json.gz",
**x_tenant_entry,
} }
@ -93,3 +67,17 @@ def response_payload():
"dossierId": "test", "dossierId": "test",
"fileId": "test", "fileId": "test",
} }
@pytest.fixture(scope="session")
def storage(test_storage_config):
    """Yield a storage client whose test bucket has been prepared for the session.

    The client is built with ``get_storage_from_config``. Before it is yielded,
    ``make_bucket`` and ``clear_bucket`` are called for
    ``test_storage_config.storage_bucket``; on teardown the bucket is cleared
    once more.
    """
    logger.debug("Setup for storage")
    storage = get_storage_from_config(test_storage_config)
    storage.make_bucket(test_storage_config.storage_bucket)
    storage.clear_bucket(test_storage_config.storage_bucket)

    yield storage

    logger.debug("Teardown for storage")
    try:
        storage.clear_bucket(test_storage_config.storage_bucket)
    except Exception:
        # Teardown stays best-effort so a cleanup failure never fails the session,
        # but it is logged rather than silently dropped. Unlike a bare ``except``,
        # this lets KeyboardInterrupt and SystemExit propagate.
        logger.exception("Failed to clear the storage bucket during teardown")

View File

@ -4,43 +4,41 @@ import time
import pytest import pytest
import requests import requests
from pyinfra.config import get_config from pyinfra.payload_processing.monitor import PrometheusMonitor
from pyinfra.payload_processing.monitor import get_monitor
@pytest.fixture(scope="class") @pytest.fixture(scope="class")
def monitor_config(): def monitored_mock_function(metric_prefix, host, port):
config = get_config()
config.prometheus_metric_prefix = "monitor_test"
config.prometheus_port = 8000
return config
@pytest.fixture(scope="class")
def prometheus_monitor(monitor_config):
return get_monitor(monitor_config)
@pytest.fixture
def monitored_mock_function(prometheus_monitor):
def process(data=None): def process(data=None):
time.sleep(2) time.sleep(2)
return ["result1", "result2", "result3"] return ["result1", "result2", "result3"]
return prometheus_monitor(process) monitor = PrometheusMonitor(metric_prefix, host, port)
return monitor(process)
@pytest.fixture
def metric_endpoint(host, port):
return f"http://{host}:{port}/prometheus"
@pytest.mark.parametrize("metric_prefix, host, port", [("test", "0.0.0.0", 8000)], scope="class")
class TestPrometheusMonitor: class TestPrometheusMonitor:
def test_prometheus_endpoint_is_available(self, prometheus_monitor, monitor_config): def test_prometheus_endpoint_is_available(self, metric_endpoint, monitored_mock_function):
resp = requests.get(f"http://{monitor_config.prometheus_host}:{monitor_config.prometheus_port}/prometheus") resp = requests.get(metric_endpoint)
assert resp.status_code == 200 assert resp.status_code == 200
def test_processing_with_a_monitored_fn_increases_parameter_counter(self, monitored_mock_function, monitor_config): def test_processing_with_a_monitored_fn_increases_parameter_counter(
self,
metric_endpoint,
metric_prefix,
monitored_mock_function,
):
monitored_mock_function(data=None) monitored_mock_function(data=None)
resp = requests.get(f"http://{monitor_config.prometheus_host}:{monitor_config.prometheus_port}/prometheus") resp = requests.get(metric_endpoint)
pattern = re.compile(r".*monitor_test_processing_time_count (\d\.\d).*") pattern = re.compile(rf".*{metric_prefix}_processing_time_count (\d\.\d).*")
assert pattern.search(resp.text).group(1) == "1.0" assert pattern.search(resp.text).group(1) == "1.0"
monitored_mock_function(data=None) monitored_mock_function(data=None)
resp = requests.get(f"http://{monitor_config.prometheus_host}:{monitor_config.prometheus_port}/prometheus") resp = requests.get(metric_endpoint)
assert pattern.search(resp.text).group(1) == "2.0" assert pattern.search(resp.text).group(1) == "2.0"

View File

@ -0,0 +1,53 @@
import pytest
from pyinfra.payload_processing.payload import (
QueueMessagePayload,
QueueMessagePayloadParser,
)
from pyinfra.utils.file_extension_parsing import make_file_extension_parser
@pytest.fixture
def expected_parsed_payload(x_tenant_id):
    """Return the ``QueueMessagePayload`` that the tests compare the parsed payload against.

    Only ``x_tenant_id`` varies; every other field is a fixed literal.
    """
    expected_fields = {
        "dossier_id": "test",
        "file_id": "test",
        "x_tenant_id": x_tenant_id,
        "target_file_extension": "json.gz",
        "response_file_extension": "json.gz",
        "target_file_type": "json",
        "target_compression_type": "gz",
        "response_file_type": "json",
        "response_compression_type": "gz",
        "target_file_name": "test/test.json.gz",
        "response_file_name": "test/test.json.gz",
    }
    return QueueMessagePayload(**expected_fields)
@pytest.fixture
def file_extension_parser(allowed_file_types, allowed_compression_types):
    """Build a file extension parser from the parametrized allowed file and compression types."""
    parser = make_file_extension_parser(allowed_file_types, allowed_compression_types)
    return parser
@pytest.fixture
def payload_parser(file_extension_parser):
    """Wrap the file extension parser fixture in a ``QueueMessagePayloadParser``."""
    parser = QueueMessagePayloadParser(file_extension_parser)
    return parser
@pytest.mark.parametrize("allowed_file_types,allowed_compression_types", [(["json", "pdf"], ["gz"])])
class TestPayload:
    """Tests for queue message payload parsing and file extension parsing."""

    @pytest.mark.parametrize("x_tenant_id", [None, "klaus"])
    def test_payload_is_parsed_correctly(self, payload_parser, payload, expected_parsed_payload):
        """The parsed payload must equal the expected payload, with and without a tenant id."""
        parsed_payload = payload_parser(payload)

        assert parsed_payload == expected_parsed_payload

    @pytest.mark.parametrize(
        "extension,expected",
        [
            ("json.gz", ("json", "gz")),
            ("json", ("json", None)),
            ("prefix.json.gz", ("json", "gz")),
        ],
    )
    def test_parse_file_extension(self, file_extension_parser, extension, expected):
        """Each extension string must parse to the expected (file type, compression type) pair."""
        parsed_extension = file_extension_parser(extension)

        assert parsed_extension == expected

View File

@ -8,14 +8,6 @@ import requests
from pyinfra.payload_processing.processor import make_payload_processor from pyinfra.payload_processing.processor import make_payload_processor
@pytest.fixture(scope="session")
def file_processor_mock():
def inner(json_file: dict):
return [json_file]
return inner
@pytest.fixture @pytest.fixture
def target_file(): def target_file():
contents = {"numberOfPages": 10, "content1": "value1", "content2": "value2"} contents = {"numberOfPages": 10, "content1": "value1", "content2": "value2"}
@ -23,54 +15,60 @@ def target_file():
@pytest.fixture @pytest.fixture
def file_names(request_payload): def file_names(payload):
dossier_id, file_id, target_suffix, response_suffix = itemgetter( dossier_id, file_id, target_suffix, response_suffix = itemgetter(
"dossierId", "dossierId",
"fileId", "fileId",
"targetFileExtension", "targetFileExtension",
"responseFileExtension", "responseFileExtension",
)(request_payload) )(payload)
return f"{dossier_id}/{file_id}.{target_suffix}", f"{dossier_id}/{file_id}.{response_suffix}" return f"{dossier_id}/{file_id}.{target_suffix}", f"{dossier_id}/{file_id}.{response_suffix}"
@pytest.fixture(scope="session") @pytest.fixture(scope="session")
def payload_processor(file_processor_mock, processing_config): def payload_processor(test_storage_config):
yield make_payload_processor(file_processor_mock, processing_config) def file_processor_mock(json_file: dict):
return [json_file]
yield make_payload_processor(file_processor_mock, test_storage_config)
@pytest.mark.parametrize("client_name", ["s3"], scope="session") @pytest.mark.parametrize("storage_backend", ["s3"], scope="session")
@pytest.mark.parametrize("bucket_name", ["testbucket"], scope="session")
@pytest.mark.parametrize("monitoring_enabled", [True, False], scope="session") @pytest.mark.parametrize("monitoring_enabled", [True, False], scope="session")
@pytest.mark.parametrize("x_tenant_id", [None])
class TestPayloadProcessor: class TestPayloadProcessor:
def test_payload_processor_yields_correct_response_and_uploads_result( def test_payload_processor_yields_correct_response_and_uploads_result(
self, self,
payload_processor, payload_processor,
storage, storage,
bucket_name, bucket_name,
request_payload, payload,
response_payload, response_payload,
target_file, target_file,
file_names, file_names,
): ):
storage.clear_bucket(bucket_name) storage.clear_bucket(bucket_name)
storage.put_object(bucket_name, file_names[0], target_file) storage.put_object(bucket_name, file_names[0], target_file)
response = payload_processor(request_payload) response = payload_processor(payload)
assert response == response_payload assert response == response_payload
data_received = storage.get_object(bucket_name, file_names[1]) data_received = storage.get_object(bucket_name, file_names[1])
assert json.loads((gzip.decompress(data_received)).decode("utf-8")) == { assert json.loads((gzip.decompress(data_received)).decode("utf-8")) == {
**request_payload, **payload,
"data": [json.loads(gzip.decompress(target_file).decode("utf-8"))], "data": [json.loads(gzip.decompress(target_file).decode("utf-8"))],
} }
def test_catching_of_processing_failure(self, payload_processor, storage, bucket_name, request_payload): def test_catching_of_processing_failure(self, payload_processor, storage, bucket_name, payload):
storage.clear_bucket(bucket_name) storage.clear_bucket(bucket_name)
with pytest.raises(Exception): with pytest.raises(Exception):
payload_processor(request_payload) payload_processor(payload)
def test_prometheus_endpoint_is_available(self, processing_config): def test_prometheus_endpoint_is_available(self, test_storage_config, monitoring_enabled, storage_backend, x_tenant_id):
if monitoring_enabled:
resp = requests.get( resp = requests.get(
f"http://{processing_config.prometheus_host}:{processing_config.prometheus_port}/prometheus" f"http://{test_storage_config.prometheus_host}:{test_storage_config.prometheus_port}/prometheus"
) )
assert resp.status_code == 200 assert resp.status_code == 200

View File

@ -1,44 +0,0 @@
import pytest
from pyinfra.config import get_config
from pyinfra.payload_processing.payload import (
QueueMessagePayload,
get_queue_message_payload_parser,
)
from pyinfra.utils.file_extension_parsing import make_file_extension_parser
@pytest.fixture(scope="session")
def payload_config():
    """Return the configuration object produced by ``get_config``."""
    config = get_config()
    return config
class TestPayload:
    """Tests for queue message payload parsing and file extension parsing."""

    def test_payload_is_parsed_correctly(self, request_payload, payload_config):
        """Parsing the request payload must yield the expected ``QueueMessagePayload``."""
        expected_payload = QueueMessagePayload(
            dossier_id="test",
            file_id="test",
            target_file_extension="json.gz",
            response_file_extension="json.gz",
            target_file_type="json",
            target_compression_type="gz",
            response_file_type="json",
            response_compression_type="gz",
            target_file_name="test/test.json.gz",
            response_file_name="test/test.json.gz",
        )

        payload_parser = get_queue_message_payload_parser(payload_config)
        parsed_payload = payload_parser(request_payload)

        assert parsed_payload == expected_payload

    @pytest.mark.parametrize(
        "extension,expected",
        [
            ("json.gz", ("json", "gz")),
            ("json", ("json", None)),
            ("prefix.json.gz", ("json", "gz")),
        ],
    )
    @pytest.mark.parametrize("allowed_file_types,allowed_compression_types", [(["json", "pdf"], ["gz"])])
    def test_parse_file_extension(self, extension, expected, allowed_file_types, allowed_compression_types):
        """Each extension string must parse to the expected (file type, compression type) pair."""
        extension_parser = make_file_extension_parser(allowed_file_types, allowed_compression_types)
        parsed_extension = extension_parser(extension)

        assert parsed_extension == expected

View File

@ -8,15 +8,16 @@ import pika.exceptions
import pytest import pytest
from pyinfra.queue.development_queue_manager import DevelopmentQueueManager from pyinfra.queue.development_queue_manager import DevelopmentQueueManager
from pyinfra.queue.queue_manager import QueueManager
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG) logger.setLevel(logging.DEBUG)
@pytest.fixture(scope="session") @pytest.fixture(scope="session")
def development_queue_manager(queue_config): def development_queue_manager(test_queue_config):
queue_config.rabbitmq_heartbeat = 7200 test_queue_config.rabbitmq_heartbeat = 7200
development_queue_manager = DevelopmentQueueManager(queue_config) development_queue_manager = DevelopmentQueueManager(test_queue_config)
yield development_queue_manager yield development_queue_manager
logger.info("Tearing down development queue manager...") logger.info("Tearing down development queue manager...")
try: try:
@ -26,10 +27,10 @@ def development_queue_manager(queue_config):
@pytest.fixture(scope="session") @pytest.fixture(scope="session")
def payload_processing_time(queue_config, offset=5): def payload_processing_time(test_queue_config, offset=5):
# FIXME: this implicitly tests the heartbeat when running the end-to-end test. There should be another way to test # FIXME: this implicitly tests the heartbeat when running the end-to-end test. There should be another way to test
# this explicitly. # this explicitly.
return queue_config.rabbitmq_heartbeat + offset return test_queue_config.rabbitmq_heartbeat + offset
@pytest.fixture(scope="session") @pytest.fixture(scope="session")
@ -48,10 +49,11 @@ def payload_processor(response_payload, payload_processing_time, payload_process
@pytest.fixture(scope="session", autouse=True) @pytest.fixture(scope="session", autouse=True)
def start_queue_consumer(queue_manager, payload_processor, sleep_seconds=5): def start_queue_consumer(test_queue_config, payload_processor, sleep_seconds=5):
def consume_queue(): def consume_queue():
queue_manager.start_consuming(payload_processor) queue_manager.start_consuming(payload_processor)
queue_manager = QueueManager(test_queue_config)
p = Process(target=consume_queue) p = Process(target=consume_queue)
p.start() p.start()
logger.info(f"Setting up consumer, waiting for {sleep_seconds}...") logger.info(f"Setting up consumer, waiting for {sleep_seconds}...")
@ -65,39 +67,40 @@ def start_queue_consumer(queue_manager, payload_processor, sleep_seconds=5):
def message_properties(message_headers): def message_properties(message_headers):
if not message_headers: if not message_headers:
return pika.BasicProperties(headers=None) return pika.BasicProperties(headers=None)
elif message_headers == "x-tenant-id": elif message_headers == "X-TENANT-ID":
return pika.BasicProperties(headers={"x-tenant-id": "redaction"}) return pika.BasicProperties(headers={"X-TENANT-ID": "redaction"})
else: else:
raise Exception(f"Invalid {message_headers=}.") raise Exception(f"Invalid {message_headers=}.")
@pytest.mark.parametrize("x_tenant_id", [None])
class TestQueueManager: class TestQueueManager:
# FIXME: All tests here are wonky. This is due to the implementation of running the process-blocking queue_manager # FIXME: All tests here are wonky. This is due to the implementation of running the process-blocking queue_manager
# in a subprocess. It is then very hard to interact directly with the subprocess. If you have a better idea, please # in a subprocess. It is then very hard to interact directly with the subprocess. If you have a better idea, please
# refactor; the tests here are insufficient to ensure the functionality of the queue manager! # refactor; the tests here are insufficient to ensure the functionality of the queue manager!
@pytest.mark.parametrize("payload_processor_type", ["mock"], scope="session") @pytest.mark.parametrize("payload_processor_type", ["mock"], scope="session")
def test_message_processing_does_not_block_heartbeat( def test_message_processing_does_not_block_heartbeat(
self, development_queue_manager, request_payload, response_payload, payload_processing_time self, development_queue_manager, payload, response_payload, payload_processing_time
): ):
development_queue_manager.clear_queues() development_queue_manager.clear_queues()
development_queue_manager.publish_request(request_payload) development_queue_manager.publish_request(payload)
time.sleep(payload_processing_time + 10) time.sleep(payload_processing_time + 10)
_, _, body = development_queue_manager.get_response() _, _, body = development_queue_manager.get_response()
result = json.loads(body) result = json.loads(body)
assert result == response_payload assert result == response_payload
@pytest.mark.parametrize("message_headers", [None, "x-tenant-id"]) @pytest.mark.parametrize("message_headers", [None, "X-TENANT-ID"])
@pytest.mark.parametrize("payload_processor_type", ["mock"], scope="session") @pytest.mark.parametrize("payload_processor_type", ["mock"], scope="session")
def test_queue_manager_forwards_message_headers( def test_queue_manager_forwards_message_headers(
self, self,
development_queue_manager, development_queue_manager,
request_payload, payload,
response_payload, response_payload,
payload_processing_time, payload_processing_time,
message_properties, message_properties,
): ):
development_queue_manager.clear_queues() development_queue_manager.clear_queues()
development_queue_manager.publish_request(request_payload, message_properties) development_queue_manager.publish_request(payload, message_properties)
time.sleep(payload_processing_time + 10) time.sleep(payload_processing_time + 10)
_, properties, _ = development_queue_manager.get_response() _, properties, _ = development_queue_manager.get_response()
assert properties.headers == message_properties.headers assert properties.headers == message_properties.headers
@ -109,12 +112,12 @@ class TestQueueManager:
def test_failed_message_processing_is_handled( def test_failed_message_processing_is_handled(
self, self,
development_queue_manager, development_queue_manager,
request_payload, payload,
response_payload, response_payload,
payload_processing_time, payload_processing_time,
): ):
development_queue_manager.clear_queues() development_queue_manager.clear_queues()
development_queue_manager.publish_request(request_payload) development_queue_manager.publish_request(payload)
time.sleep(payload_processing_time + 10) time.sleep(payload_processing_time + 10)
_, _, body = development_queue_manager.get_response() _, _, body = development_queue_manager.get_response()
assert not body assert not body

View File

@ -6,7 +6,9 @@ logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG) logger.setLevel(logging.DEBUG)
@pytest.mark.parametrize("client_name", ["azure", "s3"], scope="session") @pytest.mark.parametrize("storage_backend", ["azure", "s3"], scope="session")
@pytest.mark.parametrize("bucket_name", ["testbucket"], scope="session")
@pytest.mark.parametrize("monitoring_enabled", [False], scope="session")
class TestStorage: class TestStorage:
def test_clearing_bucket_yields_empty_bucket(self, storage, bucket_name): def test_clearing_bucket_yields_empty_bucket(self, storage, bucket_name):
storage.clear_bucket(bucket_name) storage.clear_bucket(bucket_name)