add multi-tenant storage connection 1st iteration
- forward x-tenant-id from queue message header to payload processor - add functions to receive storage infos from an endpoint or the config. This enables hashing and caching of connections created from these infos - add function to initialize storage connections from storage infos - streamline and refactor tests to make them more readable and robust and to make it easier to add new tests - update payload processor with first iteration of multi-tenancy storage connection support with connection caching and backward compatibility
This commit is contained in:
parent
52c047c47b
commit
d48e8108fd
@ -28,11 +28,9 @@ class Config:
|
||||
"PROMETHEUS_METRIC_PREFIX", "redactmanager_research_service_parameter"
|
||||
)
|
||||
|
||||
# Prometheus webserver address
|
||||
self.prometheus_host = read_from_environment("PROMETHEUS_HOST", "127.0.0.1")
|
||||
|
||||
# Prometheus webserver port
|
||||
self.prometheus_port = int(read_from_environment("PROMETHEUS_PORT", 8080))
|
||||
# Prometheus webserver address and port
|
||||
self.prometheus_host = "0.0.0.0"
|
||||
self.prometheus_port = 8080
|
||||
|
||||
# RabbitMQ host address
|
||||
self.rabbitmq_host = read_from_environment("RABBITMQ_HOST", "localhost")
|
||||
@ -94,6 +92,10 @@ class Config:
|
||||
self.allowed_file_types = ["json", "pdf"]
|
||||
self.allowed_compression_types = ["gz"]
|
||||
|
||||
# config for x-tenant-endpoint to receive storage connection information per tenant
|
||||
self.persistence_service_public_key = "redaction"
|
||||
self.persistence_service_tenant_endpoint = "http://persistence-service-v1:8080/internal-api/tenants"
|
||||
|
||||
# Value to see if we should write a consumer token to a file
|
||||
self.write_consumer_token = read_from_environment("WRITE_CONSUMER_TOKEN", "False")
|
||||
|
||||
|
||||
@ -1,2 +1,5 @@
|
||||
class ProcessingFailure(RuntimeError):
    """Raised when processing a queue message payload fails."""


class UnknownStorageBackend(Exception):
    """Raised when a storage backend name or storage-info type is not recognized."""
|
||||
@ -12,7 +12,7 @@ logger = logging.getLogger()
|
||||
|
||||
|
||||
class PrometheusMonitor:
|
||||
def __init__(self, prefix: str, port=8080, host="127.0.0.1"):
|
||||
def __init__(self, prefix: str, host: str, port: int):
|
||||
"""Register the monitoring metrics and start a webserver where they can be scraped at the endpoint
|
||||
http://{host}:{port}/prometheus
|
||||
|
||||
@ -23,12 +23,9 @@ class PrometheusMonitor:
|
||||
self.registry = CollectorRegistry()
|
||||
|
||||
self.entity_processing_time_sum = Summary(
|
||||
f"{prefix}_processing_time",
|
||||
"Summed up average processing time per entity observed",
|
||||
f"{prefix}_processing_time", "Summed up average processing time per entity observed", registry=self.registry
|
||||
)
|
||||
|
||||
self.registry.register(self.entity_processing_time_sum)
|
||||
|
||||
start_http_server(port, host, self.registry)
|
||||
|
||||
def __call__(self, process_fn: Callable) -> Callable:
|
||||
@ -58,8 +55,8 @@ class PrometheusMonitor:
|
||||
return inner
|
||||
|
||||
|
||||
def get_monitor(config: Config) -> Callable:
|
||||
def get_monitor_from_config(config: Config) -> Callable:
|
||||
if config.monitoring_enabled:
|
||||
return PrometheusMonitor(*attrgetter("prometheus_metric_prefix", "prometheus_port", "prometheus_host")(config))
|
||||
return PrometheusMonitor(*attrgetter("prometheus_metric_prefix", "prometheus_host", "prometheus_port")(config))
|
||||
else:
|
||||
return identity
|
||||
|
||||
@ -11,6 +11,8 @@ from pyinfra.utils.file_extension_parsing import make_file_extension_parser
|
||||
class QueueMessagePayload:
|
||||
dossier_id: str
|
||||
file_id: str
|
||||
x_tenant_id: Union[str, None]
|
||||
|
||||
target_file_extension: str
|
||||
response_file_extension: str
|
||||
|
||||
@ -35,6 +37,7 @@ class QueueMessagePayloadParser:
|
||||
dossier_id, file_id, target_file_extension, response_file_extension = itemgetter(
|
||||
"dossierId", "fileId", "targetFileExtension", "responseFileExtension"
|
||||
)(payload)
|
||||
x_tenant_id = payload.get("X-TENANT-ID")
|
||||
|
||||
target_file_type, target_compression_type, response_file_type, response_compression_type = chain.from_iterable(
|
||||
map(self.parse_file_extensions, [target_file_extension, response_file_extension])
|
||||
@ -46,6 +49,7 @@ class QueueMessagePayloadParser:
|
||||
return QueueMessagePayload(
|
||||
dossier_id=dossier_id,
|
||||
file_id=file_id,
|
||||
x_tenant_id=x_tenant_id,
|
||||
target_file_extension=target_file_extension,
|
||||
response_file_extension=response_file_extension,
|
||||
target_file_type=target_file_type,
|
||||
|
||||
@ -6,16 +6,20 @@ from typing import Callable, Union, List
|
||||
from funcy import compose
|
||||
|
||||
from pyinfra.config import get_config, Config
|
||||
from pyinfra.payload_processing.monitor import get_monitor
|
||||
from pyinfra.payload_processing.monitor import get_monitor_from_config
|
||||
from pyinfra.payload_processing.payload import (
|
||||
QueueMessagePayloadParser,
|
||||
get_queue_message_payload_parser,
|
||||
QueueMessagePayloadFormatter,
|
||||
get_queue_message_payload_formatter,
|
||||
)
|
||||
from pyinfra.storage import get_storage
|
||||
from pyinfra.storage.storage import make_downloader, make_uploader
|
||||
from pyinfra.storage.storages.interface import Storage
|
||||
from pyinfra.storage.storage import make_downloader, make_uploader, get_storage_from_storage_info
|
||||
from pyinfra.storage.storage_info import (
|
||||
AzureStorageInfo,
|
||||
S3StorageInfo,
|
||||
get_storage_info_from_config,
|
||||
get_storage_info_from_endpoint,
|
||||
)
|
||||
|
||||
logger = logging.getLogger()
|
||||
logger.setLevel(get_config().logging_level_root)
|
||||
@ -24,8 +28,8 @@ logger.setLevel(get_config().logging_level_root)
|
||||
class PayloadProcessor:
|
||||
def __init__(
|
||||
self,
|
||||
storage: Storage,
|
||||
bucket: str,
|
||||
default_storage_info: Union[AzureStorageInfo, S3StorageInfo],
|
||||
get_storage_info_from_tenant_id,
|
||||
payload_parser: QueueMessagePayloadParser,
|
||||
payload_formatter: QueueMessagePayloadFormatter,
|
||||
data_processor: Callable,
|
||||
@ -33,8 +37,9 @@ class PayloadProcessor:
|
||||
"""Wraps an analysis function specified by a service (e.g. NER service) in pre- and post-processing steps.
|
||||
|
||||
Args:
|
||||
storage: The storage to use for downloading and uploading files
|
||||
bucket: The bucket to use for downloading and uploading files
|
||||
default_storage_info: The default storage info used to create the storage connection. This is only used if
|
||||
x_tenant_id is not provided in the queue payload.
|
||||
get_storage_info_from_tenant_id: Callable to acquire storage info from a given tenant id.
|
||||
payload_parser: Parser that translates the queue message payload to the required QueueMessagePayload object
|
||||
payload_formatter: Formatter for the storage upload result and the queue message response body
|
||||
data_processor: The analysis function to be called with the downloaded file
|
||||
@ -46,8 +51,10 @@ class PayloadProcessor:
|
||||
self.format_to_queue_message_response_body = payload_formatter.format_to_queue_message_response_body
|
||||
self.process_data = data_processor
|
||||
|
||||
self.make_downloader = partial(make_downloader, storage, bucket)
|
||||
self.make_uploader = partial(make_uploader, storage, bucket)
|
||||
self.get_storage_info_from_tenant_id = get_storage_info_from_tenant_id
|
||||
self.default_storage_info = default_storage_info
|
||||
# TODO: use lru-dict
|
||||
self.storages = {}
|
||||
|
||||
def __call__(self, queue_message_payload: dict) -> dict:
|
||||
"""Processes a queue message payload.
|
||||
@ -71,8 +78,16 @@ class PayloadProcessor:
|
||||
payload = self.parse_payload(queue_message_payload)
|
||||
logger.info(f"Processing {asdict(payload)} ...")
|
||||
|
||||
download_file_to_process = self.make_downloader(payload.target_file_type, payload.target_compression_type)
|
||||
upload_processing_result = self.make_uploader(payload.response_file_type, payload.response_compression_type)
|
||||
storage_info = self._get_storage_info(payload.x_tenant_id)
|
||||
bucket = storage_info.bucket_name
|
||||
storage = self._get_storage(storage_info)
|
||||
|
||||
download_file_to_process = make_downloader(
|
||||
storage, bucket, payload.target_file_type, payload.target_compression_type
|
||||
)
|
||||
upload_processing_result = make_uploader(
|
||||
storage, bucket, payload.response_file_type, payload.response_compression_type
|
||||
)
|
||||
format_result_for_storage = partial(self.format_result_for_storage, payload)
|
||||
|
||||
processing_pipeline = compose(format_result_for_storage, self.process_data, download_file_to_process)
|
||||
@ -83,17 +98,40 @@ class PayloadProcessor:
|
||||
|
||||
return self.format_to_queue_message_response_body(payload)
|
||||
|
||||
def _get_storage_info(self, x_tenant_id=None):
    """Resolve the storage info for a tenant, falling back to the default.

    A truthy ``x_tenant_id`` is resolved via the injected
    ``get_storage_info_from_tenant_id`` callable; otherwise the configured
    ``default_storage_info`` is returned (backwards-compatible single-tenant path).
    """
    if not x_tenant_id:
        return self.default_storage_info
    return self.get_storage_info_from_tenant_id(x_tenant_id)
|
||||
|
||||
def _get_storage(self, storage_info):
    """Return a (possibly cached) storage connection for ``storage_info``.

    Connections are cached in ``self.storages`` keyed by the hashable storage
    info, so repeated payloads for the same tenant reuse one connection.
    """
    # Populate the cache on first use; subsequent calls hit the dict lookup only.
    if storage_info not in self.storages:
        self.storages[storage_info] = get_storage_from_storage_info(storage_info)
    return self.storages[storage_info]
|
||||
|
||||
|
||||
def make_payload_processor(data_processor: Callable, config: Union[None, Config] = None) -> PayloadProcessor:
|
||||
"""Produces payload processor for queue manager."""
|
||||
config = config or get_config()
|
||||
|
||||
bucket: str = config.storage_bucket
|
||||
storage: Storage = get_storage(config)
|
||||
monitor = get_monitor(config)
|
||||
default_storage_info: Union[AzureStorageInfo, S3StorageInfo] = get_storage_info_from_config(config)
|
||||
get_storage_info_from_tenant_id = partial(
|
||||
get_storage_info_from_endpoint,
|
||||
config.persistence_service_public_key,
|
||||
config.persistence_service_tenant_endpoint,
|
||||
)
|
||||
monitor = get_monitor_from_config(config)
|
||||
payload_parser: QueueMessagePayloadParser = get_queue_message_payload_parser(config)
|
||||
payload_formatter: QueueMessagePayloadFormatter = get_queue_message_payload_formatter()
|
||||
|
||||
data_processor = monitor(data_processor)
|
||||
|
||||
return PayloadProcessor(storage, bucket, payload_parser, payload_formatter, data_processor)
|
||||
return PayloadProcessor(
|
||||
default_storage_info,
|
||||
get_storage_info_from_tenant_id,
|
||||
payload_parser,
|
||||
payload_formatter,
|
||||
data_processor,
|
||||
)
|
||||
|
||||
@ -12,6 +12,7 @@ from pika.adapters.blocking_connection import BlockingChannel
|
||||
from pyinfra.config import Config
|
||||
from pyinfra.exception import ProcessingFailure
|
||||
from pyinfra.payload_processing.processor import PayloadProcessor
|
||||
from pyinfra.utils.dict import save_project
|
||||
|
||||
CONFIG = Config()
|
||||
|
||||
@ -164,8 +165,8 @@ class QueueManager:
|
||||
except Exception as err:
|
||||
raise ProcessingFailure("QueueMessagePayload processing failed") from err
|
||||
|
||||
def acknowledge_message_and_publish_response(frame, properties, response_body):
|
||||
response_properties = pika.BasicProperties(headers=properties.headers) if properties.headers else None
|
||||
def acknowledge_message_and_publish_response(frame, headers, response_body):
|
||||
response_properties = pika.BasicProperties(headers=headers) if headers else None
|
||||
self._channel.basic_publish("", self._output_queue, json.dumps(response_body).encode(), response_properties)
|
||||
self.logger.info(
|
||||
"Result published, acknowledging incoming message with delivery_tag %s",
|
||||
@ -190,12 +191,15 @@ class QueueManager:
|
||||
|
||||
try:
|
||||
self.logger.debug("Processing (%s, %s, %s)", frame, properties, body)
|
||||
processing_result = process_message_body_and_await_result(json.loads(body))
|
||||
filtered_message_headers = save_project(properties.headers, ["X-TENANT-ID"]) # TODO: parametrize key?
|
||||
message_body = {**json.loads(body), **filtered_message_headers}
|
||||
|
||||
processing_result = process_message_body_and_await_result(message_body)
|
||||
self.logger.info(
|
||||
"Processed message with delivery_tag %s, publishing result to result-queue",
|
||||
frame.delivery_tag,
|
||||
)
|
||||
acknowledge_message_and_publish_response(frame, properties, processing_result)
|
||||
acknowledge_message_and_publish_response(frame, filtered_message_headers, processing_result)
|
||||
|
||||
except ProcessingFailure:
|
||||
self.logger.info(
|
||||
|
||||
@ -1,3 +1,3 @@
|
||||
from pyinfra.storage.storage import get_storage
|
||||
from pyinfra.storage.storage import get_storage_from_config
|
||||
|
||||
__all__ = ["get_storage"]
|
||||
__all__ = ["get_storage_from_config"]
|
||||
|
||||
@ -1,28 +1,45 @@
|
||||
from functools import lru_cache, partial
|
||||
from typing import Callable
|
||||
from typing import Callable, Union
|
||||
|
||||
from azure.storage.blob import BlobServiceClient
|
||||
from funcy import compose
|
||||
from minio import Minio
|
||||
|
||||
from pyinfra.config import Config
|
||||
from pyinfra.storage.storages.azure import get_azure_storage
|
||||
from pyinfra.storage.storages.s3 import get_s3_storage
|
||||
from pyinfra.exception import UnknownStorageBackend
|
||||
from pyinfra.storage.storage_info import AzureStorageInfo, S3StorageInfo, get_storage_info_from_config
|
||||
from pyinfra.storage.storages.azure import AzureStorage
|
||||
from pyinfra.storage.storages.interface import Storage
|
||||
from pyinfra.storage.storages.s3 import S3Storage
|
||||
from pyinfra.utils.compressing import get_decompressor, get_compressor
|
||||
from pyinfra.utils.encoding import get_decoder, get_encoder
|
||||
|
||||
|
||||
def get_storage(config: Config) -> Storage:
|
||||
def get_storage_from_config(config: Config) -> Storage:
|
||||
|
||||
if config.storage_backend == "s3":
|
||||
storage = get_s3_storage(config)
|
||||
elif config.storage_backend == "azure":
|
||||
storage = get_azure_storage(config)
|
||||
else:
|
||||
raise Exception(f"Unknown storage backend '{config.storage_backend}'.")
|
||||
storage_info = get_storage_info_from_config(config)
|
||||
storage = get_storage_from_storage_info(storage_info)
|
||||
|
||||
return storage
|
||||
|
||||
|
||||
def get_storage_from_storage_info(storage_info: Union[AzureStorageInfo, S3StorageInfo]) -> Storage:
    """Create a storage client from a storage info object.

    Args:
        storage_info: Either an AzureStorageInfo (connection string based) or an
            S3StorageInfo (endpoint/key/secret based).

    Returns:
        An AzureStorage or S3Storage wrapping the corresponding SDK client.

    Raises:
        UnknownStorageBackend: If storage_info is neither of the supported types.
    """
    if isinstance(storage_info, AzureStorageInfo):
        return AzureStorage(BlobServiceClient.from_connection_string(conn_str=storage_info.connection_string))
    if isinstance(storage_info, S3StorageInfo):
        return S3Storage(
            Minio(
                secure=storage_info.secure,
                endpoint=storage_info.endpoint,
                access_key=storage_info.access_key,
                secret_key=storage_info.secret_key,
                region=storage_info.region,
            )
        )
    # Consistency fix: include a message (like get_storage_info_from_config does)
    # instead of raising a bare, message-less exception.
    raise UnknownStorageBackend(f"Unknown storage info type '{type(storage_info).__name__}'.")
|
||||
|
||||
|
||||
def verify_existence(storage: Storage, bucket: str, file_name: str) -> str:
|
||||
if not storage.exists(bucket, file_name):
|
||||
raise FileNotFoundError(f"{file_name=} name not found on storage in {bucket=}.")
|
||||
|
||||
89
pyinfra/storage/storage_info.py
Normal file
89
pyinfra/storage/storage_info.py
Normal file
@ -0,0 +1,89 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import Union
|
||||
|
||||
import requests
|
||||
|
||||
from pyinfra.config import Config
|
||||
from pyinfra.exception import UnknownStorageBackend
|
||||
from pyinfra.utils.cipher import decrypt
|
||||
from pyinfra.utils.url_parsing import validate_and_parse_s3_endpoint
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class AzureStorageInfo:
    """Connection info for an Azure Blob Storage backend.

    Frozen (immutable) and hashable so instances can serve as cache keys for
    storage connections.
    """

    # Azure connection string (account name/key, endpoint suffix).
    connection_string: str
    # Blob container used for downloads/uploads.
    bucket_name: str

    def __hash__(self) -> int:
        # bucket_name is deliberately excluded: the hash tracks the connection
        # identity only. Equality (dataclass-generated) still compares all
        # fields, so infos differing only in bucket collide but stay distinct.
        return hash(self.connection_string)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class S3StorageInfo:
    """Connection info for an S3-compatible storage backend.

    Frozen (immutable) and hashable so instances can serve as cache keys for
    storage connections.
    """

    # Whether to connect via TLS.
    secure: bool
    # Host[:port] of the S3 endpoint (scheme stripped).
    endpoint: str
    access_key: str
    secret_key: str
    region: str
    # Bucket used for downloads/uploads.
    bucket_name: str

    def __hash__(self) -> int:
        # bucket_name is deliberately excluded: the hash tracks the connection
        # identity only. Equality (dataclass-generated) still compares all
        # fields, so infos differing only in bucket collide but stay distinct.
        return hash((self.secure, self.endpoint, self.access_key, self.secret_key, self.region))
|
||||
|
||||
|
||||
def get_storage_info_from_endpoint(
    public_key: str, endpoint: str, x_tenant_id: str
) -> Union[AzureStorageInfo, S3StorageInfo]:
    """Fetch and decrypt per-tenant storage connection info from the persistence service.

    Args:
        public_key: Key used to decrypt the secrets in the endpoint response.
            (Fix: previously shadowed by a hardcoded "redaction"; the parameter
            is now honored — callers already pass the configured value.)
        endpoint: Base URL of the tenant endpoint; the tenant id is appended as
            a path segment.
        x_tenant_id: Tenant whose storage connection info should be fetched.

    Returns:
        An AzureStorageInfo or S3StorageInfo built from the endpoint response.

    Raises:
        UnknownStorageBackend: If the response contains no known storage connection.
        ValueError: If the response contains both connection types at once.
    """
    # Bounded timeout so a stalled persistence service cannot hang the consumer.
    resp = requests.get(f"{endpoint}/{x_tenant_id}", timeout=30).json()

    maybe_azure = resp.get("azureStorageConnection")
    # BUG FIX: this previously read "azureStorageConnection" a second time,
    # which made the S3 branch unreachable and tripped the ambiguity check
    # whenever an Azure connection was present.
    maybe_s3 = resp.get("s3StorageConnection")

    # Explicit check instead of `assert` (asserts are stripped under `python -O`).
    if maybe_azure and maybe_s3:
        raise ValueError("Tenant endpoint returned both an Azure and an S3 storage connection.")

    if maybe_azure:
        connection_string = decrypt(public_key, maybe_azure["connectionString"])
        return AzureStorageInfo(
            connection_string=connection_string,
            bucket_name=maybe_azure["containerName"],
        )

    if maybe_s3:
        secure, s3_endpoint = validate_and_parse_s3_endpoint(maybe_s3["endpoint"])
        secret = decrypt(public_key, maybe_s3["secret"])
        return S3StorageInfo(
            secure=secure,
            endpoint=s3_endpoint,
            access_key=maybe_s3["key"],
            secret_key=secret,
            region=maybe_s3["region"],
            # BUG FIX: previously passed the whole connection dict as the bucket
            # name. NOTE(review): response key assumed to be "bucket" — confirm
            # against the persistence-service response schema.
            bucket_name=maybe_s3["bucket"],
        )

    raise UnknownStorageBackend("Tenant endpoint response contained no known storage connection.")
|
||||
|
||||
|
||||
def get_storage_info_from_config(config: Config) -> Union[AzureStorageInfo, S3StorageInfo]:
    """Build the default storage info from the service configuration.

    Args:
        config: Service configuration naming the backend ("s3" or "azure") and
            carrying its connection settings.

    Returns:
        S3StorageInfo or AzureStorageInfo populated from the config.

    Raises:
        UnknownStorageBackend: If ``config.storage_backend`` names no supported backend.
    """
    backend = config.storage_backend

    if backend == "s3":
        return S3StorageInfo(
            secure=config.storage_secure_connection,
            endpoint=config.storage_endpoint,
            access_key=config.storage_key,
            secret_key=config.storage_secret,
            region=config.storage_region,
            bucket_name=config.storage_bucket,
        )

    if backend == "azure":
        return AzureStorageInfo(
            connection_string=config.storage_azureconnectionstring,
            bucket_name=config.storage_bucket,
        )

    raise UnknownStorageBackend(f"Unknown storage backend '{backend}'.")
|
||||
@ -77,5 +77,5 @@ class AzureStorage(Storage):
|
||||
return zip(repeat(bucket_name), map(attrgetter("name"), blobs))
|
||||
|
||||
|
||||
def get_azure_storage(config: Config):
|
||||
def get_azure_storage_from_config(config: Config):
|
||||
return AzureStorage(BlobServiceClient.from_connection_string(conn_str=config.storage_azureconnectionstring))
|
||||
|
||||
@ -67,7 +67,7 @@ class S3Storage(Storage):
|
||||
return zip(repeat(bucket_name), map(attrgetter("object_name"), objs))
|
||||
|
||||
|
||||
def get_s3_storage(config: Config):
|
||||
def get_s3_storage_from_config(config: Config):
|
||||
return S3Storage(
|
||||
Minio(
|
||||
secure=config.storage_secure_connection,
|
||||
|
||||
5
pyinfra/utils/dict.py
Normal file
5
pyinfra/utils/dict.py
Normal file
@ -0,0 +1,5 @@
|
||||
from funcy import project
|
||||
|
||||
|
||||
def save_project(mapping, keys) -> dict:
    """None-safe projection: keep only ``keys`` from ``mapping``.

    Keys absent from the mapping are silently skipped; a falsy mapping
    (``None`` or empty) yields ``{}``. NOTE(review): the name likely means
    "safe_project" — kept as-is since callers depend on it.

    Args:
        mapping: A mapping (or None) to project.
        keys: Iterable of keys to keep.

    Returns:
        A plain dict with the selected key/value pairs.
    """
    if not mapping:
        return {}
    # Stdlib comprehension replaces funcy.project — same key order and
    # skip-missing semantics, no third-party dependency for two lines.
    return {key: mapping[key] for key in keys if key in mapping}
|
||||
@ -7,7 +7,7 @@ import pika
|
||||
|
||||
from pyinfra.config import get_config
|
||||
from pyinfra.queue.development_queue_manager import DevelopmentQueueManager
|
||||
from pyinfra.storage.storages.s3 import get_s3_storage
|
||||
from pyinfra.storage.storages.s3 import get_s3_storage_from_config
|
||||
|
||||
CONFIG = get_config()
|
||||
logging.basicConfig()
|
||||
@ -26,7 +26,7 @@ def upload_json_and_make_message_body():
|
||||
object_name = f"{dossier_id}/{file_id}.{suffix}"
|
||||
data = gzip.compress(json.dumps(content).encode("utf-8"))
|
||||
|
||||
storage = get_s3_storage(CONFIG)
|
||||
storage = get_s3_storage_from_config(CONFIG)
|
||||
if not storage.has_bucket(bucket):
|
||||
storage.make_bucket(bucket)
|
||||
storage.put_object(bucket, object_name, data)
|
||||
@ -46,10 +46,10 @@ def main():
|
||||
|
||||
message = upload_json_and_make_message_body()
|
||||
|
||||
development_queue_manager.publish_request(message, pika.BasicProperties(headers={"x-tenant-id": "redaction"}))
|
||||
development_queue_manager.publish_request(message, pika.BasicProperties(headers={"X-TENANT-ID": "redaction"}))
|
||||
logger.info(f"Put {message} on {CONFIG.request_queue}")
|
||||
|
||||
storage = get_s3_storage(CONFIG)
|
||||
storage = get_s3_storage_from_config(CONFIG)
|
||||
for method_frame, properties, body in development_queue_manager._channel.consume(
|
||||
queue=CONFIG.response_queue, inactivity_timeout=15
|
||||
):
|
||||
|
||||
194
test.ipynb
194
test.ipynb
@ -1,194 +0,0 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 10,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"ename": "ModuleNotFoundError",
|
||||
"evalue": "No module named 'pprint.pprint'; 'pprint' is not a package",
|
||||
"output_type": "error",
|
||||
"traceback": [
|
||||
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
|
||||
"\u001b[0;31mModuleNotFoundError\u001b[0m Traceback (most recent call last)",
|
||||
"Cell \u001b[0;32mIn [10], line 4\u001b[0m\n\u001b[1;32m 2\u001b[0m \u001b[38;5;28;01mimport\u001b[39;00m \u001b[38;5;21;01myaml\u001b[39;00m\n\u001b[1;32m 3\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01myaml\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mloader\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m FullLoader\n\u001b[0;32m----> 4\u001b[0m \u001b[38;5;28;01mimport\u001b[39;00m \u001b[38;5;21;01mpprint\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mpprint\u001b[39;00m \u001b[38;5;28;01mas\u001b[39;00m \u001b[38;5;21;01mpp\u001b[39;00m\n",
|
||||
"\u001b[0;31mModuleNotFoundError\u001b[0m: No module named 'pprint.pprint'; 'pprint' is not a package"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"import pyinfra\n",
|
||||
"import yaml\n",
|
||||
"from yaml.loader import FullLoader\n",
|
||||
"import pprint"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 12,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"{'logging': 0,\n",
|
||||
" 'mock_analysis_endpoint': 'http://127.0.0.1:5000',\n",
|
||||
" 'service': {'operations': {'classify': {'input': {'extension': 'cls_in.gz',\n",
|
||||
" 'multi': True,\n",
|
||||
" 'subdir': ''},\n",
|
||||
" 'output': {'extension': 'cls_out.gz',\n",
|
||||
" 'subdir': ''}},\n",
|
||||
" 'default': {'input': {'extension': 'IN.gz',\n",
|
||||
" 'multi': False,\n",
|
||||
" 'subdir': ''},\n",
|
||||
" 'output': {'extension': 'OUT.gz',\n",
|
||||
" 'subdir': ''}},\n",
|
||||
" 'extract': {'input': {'extension': 'extr_in.gz',\n",
|
||||
" 'multi': False,\n",
|
||||
" 'subdir': ''},\n",
|
||||
" 'output': {'extension': 'gz',\n",
|
||||
" 'subdir': 'extractions'}},\n",
|
||||
" 'rotate': {'input': {'extension': 'rot_in.gz',\n",
|
||||
" 'multi': False,\n",
|
||||
" 'subdir': ''},\n",
|
||||
" 'output': {'extension': 'rot_out.gz',\n",
|
||||
" 'subdir': ''}},\n",
|
||||
" 'stream_pages': {'input': {'extension': 'pgs_in.gz',\n",
|
||||
" 'multi': False,\n",
|
||||
" 'subdir': ''},\n",
|
||||
" 'output': {'extension': 'pgs_out.gz',\n",
|
||||
" 'subdir': 'pages'}},\n",
|
||||
" 'upper': {'input': {'extension': 'up_in.gz',\n",
|
||||
" 'multi': False,\n",
|
||||
" 'subdir': ''},\n",
|
||||
" 'output': {'extension': 'up_out.gz',\n",
|
||||
" 'subdir': ''}}},\n",
|
||||
" 'response_formatter': 'identity'},\n",
|
||||
" 'storage': {'aws': {'access_key': 'AKIA4QVP6D4LCDAGYGN2',\n",
|
||||
" 'endpoint': 'https://s3.amazonaws.com',\n",
|
||||
" 'region': '$STORAGE_REGION|\"eu-west-1\"',\n",
|
||||
" 'secret_key': '8N6H1TUHTsbvW2qMAm7zZlJ63hMqjcXAsdN7TYED'},\n",
|
||||
" 'azure': {'connection_string': 'DefaultEndpointsProtocol=https;AccountName=iqserdevelopment;AccountKey=4imAbV9PYXaztSOMpIyAClg88bAZCXuXMGJG0GA1eIBpdh2PlnFGoRBnKqLy2YZUSTmZ3wJfC7tzfHtuC6FEhQ==;EndpointSuffix=core.windows.net'},\n",
|
||||
" 'bucket': 'pyinfra-test-bucket',\n",
|
||||
" 'minio': {'access_key': 'root',\n",
|
||||
" 'endpoint': 'http://127.0.0.1:9000',\n",
|
||||
" 'region': None,\n",
|
||||
" 'secret_key': 'password'}},\n",
|
||||
" 'use_docker_fixture': 1,\n",
|
||||
" 'webserver': {'host': '$SERVER_HOST|\"127.0.0.1\"',\n",
|
||||
" 'mode': '$SERVER_MODE|production',\n",
|
||||
" 'port': '$SERVER_PORT|5000'}}\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"\n",
|
||||
"# Open the file and load the file\n",
|
||||
"with open('./tests/config.yml') as f:\n",
|
||||
" data = yaml.load(f, Loader=FullLoader)\n",
|
||||
" pprint.pprint(data)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 22,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import os\n",
|
||||
"os.environ[\"STORAGE_BACKEND\"] = \"azure\"\n",
|
||||
"\n",
|
||||
"# always the same\n",
|
||||
"os.environ[\"STORAGE_BUCKET_NAME\"] = \"pyinfra-test-bucket\"\n",
|
||||
"\n",
|
||||
"# s3\n",
|
||||
"os.environ[\"STORAGE_ENDPOINT\"] = \"https://s3.amazonaws.com\"\n",
|
||||
"os.environ[\"STORAGE_KEY\"] = \"AKIA4QVP6D4LCDAGYGN2\"\n",
|
||||
"os.environ[\"STORAGE_SECRET\"] = \"8N6H1TUHTsbvW2qMAm7zZlJ63hMqjcXAsdN7TYED\"\n",
|
||||
"os.environ[\"STORAGE_REGION\"] = \"eu-west-1\"\n",
|
||||
"\n",
|
||||
"# aks\n",
|
||||
"os.environ[\"STORAGE_AZURECONNECTIONSTRING\"] = \"DefaultEndpointsProtocol=https;AccountName=iqserdevelopment;AccountKey=4imAbV9PYXaztSOMpIyAClg88bAZCXuXMGJG0GA1eIBpdh2PlnFGoRBnKqLy2YZUSTmZ3wJfC7tzfHtuC6FEhQ==;EndpointSuffix=core.windows.net\""
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 23,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"ename": "Exception",
|
||||
"evalue": "Unknown storage backend 'aks'.",
|
||||
"output_type": "error",
|
||||
"traceback": [
|
||||
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
|
||||
"\u001b[0;31mException\u001b[0m Traceback (most recent call last)",
|
||||
"Cell \u001b[0;32mIn [23], line 2\u001b[0m\n\u001b[1;32m 1\u001b[0m config \u001b[38;5;241m=\u001b[39m pyinfra\u001b[38;5;241m.\u001b[39mconfig\u001b[38;5;241m.\u001b[39mget_config()\n\u001b[0;32m----> 2\u001b[0m storage \u001b[38;5;241m=\u001b[39m \u001b[43mpyinfra\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mstorage\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mget_storage\u001b[49m\u001b[43m(\u001b[49m\u001b[43mconfig\u001b[49m\u001b[43m)\u001b[49m\n",
|
||||
"File \u001b[0;32m~/dev/pyinfra/pyinfra/storage/storage.py:15\u001b[0m, in \u001b[0;36mget_storage\u001b[0;34m(config)\u001b[0m\n\u001b[1;32m 13\u001b[0m storage \u001b[39m=\u001b[39m get_azure_storage(config)\n\u001b[1;32m 14\u001b[0m \u001b[39melse\u001b[39;00m:\n\u001b[0;32m---> 15\u001b[0m \u001b[39mraise\u001b[39;00m \u001b[39mException\u001b[39;00m(\u001b[39mf\u001b[39m\u001b[39m\"\u001b[39m\u001b[39mUnknown storage backend \u001b[39m\u001b[39m'\u001b[39m\u001b[39m{\u001b[39;00mconfig\u001b[39m.\u001b[39mstorage_backend\u001b[39m}\u001b[39;00m\u001b[39m'\u001b[39m\u001b[39m.\u001b[39m\u001b[39m\"\u001b[39m)\n\u001b[1;32m 17\u001b[0m \u001b[39mreturn\u001b[39;00m storage\n",
|
||||
"\u001b[0;31mException\u001b[0m: Unknown storage backend 'aks'."
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"config = pyinfra.config.get_config()\n",
|
||||
"storage = pyinfra.storage.get_storage(config)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 21,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"False"
|
||||
]
|
||||
},
|
||||
"execution_count": 21,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"storage.has_bucket(config.storage_bucket)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3.8.13 ('pyinfra-TboPpZ8z-py3.8')",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.8.13"
|
||||
},
|
||||
"orig_nbformat": 4,
|
||||
"vscode": {
|
||||
"interpreter": {
|
||||
"hash": "10d7419af5ea6dfec0078ebc9d6fa1a9383fe9894853f90dc7d29a81b3de2c78"
|
||||
}
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 2
|
||||
}
|
||||
@ -6,8 +6,7 @@ import pytest
|
||||
import testcontainers.compose
|
||||
|
||||
from pyinfra.config import get_config
|
||||
from pyinfra.queue.queue_manager import QueueManager
|
||||
from pyinfra.storage import get_storage
|
||||
from pyinfra.storage import get_storage_from_config
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
logger.setLevel(logging.DEBUG)
|
||||
@ -30,60 +29,35 @@ def docker_compose(sleep_seconds=30):
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def storage_config(client_name):
|
||||
def test_storage_config(storage_backend, bucket_name, monitoring_enabled):
|
||||
config = get_config()
|
||||
config.storage_backend = client_name
|
||||
config.storage_backend = storage_backend
|
||||
config.storage_bucket = bucket_name
|
||||
config.storage_azureconnectionstring = "DefaultEndpointsProtocol=https;AccountName=iqserdevelopment;AccountKey=4imAbV9PYXaztSOMpIyAClg88bAZCXuXMGJG0GA1eIBpdh2PlnFGoRBnKqLy2YZUSTmZ3wJfC7tzfHtuC6FEhQ==;EndpointSuffix=core.windows.net"
|
||||
config.monitoring_enabled = monitoring_enabled
|
||||
config.prometheus_metric_prefix = "test"
|
||||
config.prometheus_port = 8080
|
||||
config.prometheus_host = "0.0.0.0"
|
||||
return config
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def processing_config(storage_config, monitoring_enabled):
|
||||
storage_config.monitoring_enabled = monitoring_enabled
|
||||
return storage_config
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def bucket_name(storage_config):
|
||||
return storage_config.storage_bucket
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def storage(storage_config):
|
||||
logger.debug("Setup for storage")
|
||||
storage = get_storage(storage_config)
|
||||
storage.make_bucket(storage_config.storage_bucket)
|
||||
storage.clear_bucket(storage_config.storage_bucket)
|
||||
yield storage
|
||||
logger.debug("Teardown for storage")
|
||||
try:
|
||||
storage.clear_bucket(storage_config.storage_bucket)
|
||||
except:
|
||||
pass
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def queue_config(payload_processor_type):
|
||||
def test_queue_config():
|
||||
config = get_config()
|
||||
# FIXME: It looks like rabbitmq_heartbeat has to be greater than rabbitmq_connection_sleep. If this is expected, the
|
||||
    # user should not be able to insert non-working values.
|
||||
config.rabbitmq_heartbeat = config.rabbitmq_connection_sleep + 1
|
||||
config.rabbitmq_connection_sleep = 2
|
||||
config.rabbitmq_heartbeat = 4
|
||||
return config
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def queue_manager(queue_config):
|
||||
queue_manager = QueueManager(queue_config)
|
||||
return queue_manager
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def request_payload():
|
||||
def payload(x_tenant_id):
|
||||
x_tenant_entry = {"X-TENANT-ID": x_tenant_id} if x_tenant_id else {}
|
||||
return {
|
||||
"dossierId": "test",
|
||||
"fileId": "test",
|
||||
"targetFileExtension": "json.gz",
|
||||
"responseFileExtension": "json.gz",
|
||||
**x_tenant_entry,
|
||||
}
|
||||
|
||||
|
||||
@ -93,3 +67,17 @@ def response_payload():
|
||||
"dossierId": "test",
|
||||
"fileId": "test",
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def storage(test_storage_config):
|
||||
logger.debug("Setup for storage")
|
||||
storage = get_storage_from_config(test_storage_config)
|
||||
storage.make_bucket(test_storage_config.storage_bucket)
|
||||
storage.clear_bucket(test_storage_config.storage_bucket)
|
||||
yield storage
|
||||
logger.debug("Teardown for storage")
|
||||
try:
|
||||
storage.clear_bucket(test_storage_config.storage_bucket)
|
||||
except:
|
||||
pass
|
||||
|
||||
@ -4,43 +4,41 @@ import time
|
||||
import pytest
|
||||
import requests
|
||||
|
||||
from pyinfra.config import get_config
|
||||
from pyinfra.payload_processing.monitor import get_monitor
|
||||
from pyinfra.payload_processing.monitor import PrometheusMonitor
|
||||
|
||||
|
||||
@pytest.fixture(scope="class")
|
||||
def monitor_config():
|
||||
config = get_config()
|
||||
config.prometheus_metric_prefix = "monitor_test"
|
||||
config.prometheus_port = 8000
|
||||
return config
|
||||
|
||||
|
||||
@pytest.fixture(scope="class")
|
||||
def prometheus_monitor(monitor_config):
|
||||
return get_monitor(monitor_config)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def monitored_mock_function(prometheus_monitor):
|
||||
def monitored_mock_function(metric_prefix, host, port):
|
||||
def process(data=None):
|
||||
time.sleep(2)
|
||||
return ["result1", "result2", "result3"]
|
||||
|
||||
return prometheus_monitor(process)
|
||||
monitor = PrometheusMonitor(metric_prefix, host, port)
|
||||
return monitor(process)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def metric_endpoint(host, port):
|
||||
return f"http://{host}:{port}/prometheus"
|
||||
|
||||
|
||||
@pytest.mark.parametrize("metric_prefix, host, port", [("test", "0.0.0.0", 8000)], scope="class")
|
||||
class TestPrometheusMonitor:
|
||||
def test_prometheus_endpoint_is_available(self, prometheus_monitor, monitor_config):
|
||||
resp = requests.get(f"http://{monitor_config.prometheus_host}:{monitor_config.prometheus_port}/prometheus")
|
||||
def test_prometheus_endpoint_is_available(self, metric_endpoint, monitored_mock_function):
|
||||
resp = requests.get(metric_endpoint)
|
||||
assert resp.status_code == 200
|
||||
|
||||
def test_processing_with_a_monitored_fn_increases_parameter_counter(self, monitored_mock_function, monitor_config):
|
||||
def test_processing_with_a_monitored_fn_increases_parameter_counter(
|
||||
self,
|
||||
metric_endpoint,
|
||||
metric_prefix,
|
||||
monitored_mock_function,
|
||||
):
|
||||
monitored_mock_function(data=None)
|
||||
resp = requests.get(f"http://{monitor_config.prometheus_host}:{monitor_config.prometheus_port}/prometheus")
|
||||
pattern = re.compile(r".*monitor_test_processing_time_count (\d\.\d).*")
|
||||
resp = requests.get(metric_endpoint)
|
||||
pattern = re.compile(rf".*{metric_prefix}_processing_time_count (\d\.\d).*")
|
||||
assert pattern.search(resp.text).group(1) == "1.0"
|
||||
|
||||
monitored_mock_function(data=None)
|
||||
resp = requests.get(f"http://{monitor_config.prometheus_host}:{monitor_config.prometheus_port}/prometheus")
|
||||
resp = requests.get(metric_endpoint)
|
||||
assert pattern.search(resp.text).group(1) == "2.0"
|
||||
|
||||
53
tests/payload_parsing_test.py
Normal file
53
tests/payload_parsing_test.py
Normal file
@ -0,0 +1,53 @@
|
||||
import pytest
|
||||
|
||||
from pyinfra.payload_processing.payload import (
|
||||
QueueMessagePayload,
|
||||
QueueMessagePayloadParser,
|
||||
)
|
||||
from pyinfra.utils.file_extension_parsing import make_file_extension_parser
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def expected_parsed_payload(x_tenant_id):
|
||||
return QueueMessagePayload(
|
||||
dossier_id="test",
|
||||
file_id="test",
|
||||
x_tenant_id=x_tenant_id,
|
||||
target_file_extension="json.gz",
|
||||
response_file_extension="json.gz",
|
||||
target_file_type="json",
|
||||
target_compression_type="gz",
|
||||
response_file_type="json",
|
||||
response_compression_type="gz",
|
||||
target_file_name="test/test.json.gz",
|
||||
response_file_name="test/test.json.gz",
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def file_extension_parser(allowed_file_types, allowed_compression_types):
|
||||
return make_file_extension_parser(allowed_file_types, allowed_compression_types)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def payload_parser(file_extension_parser):
|
||||
return QueueMessagePayloadParser(file_extension_parser)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("allowed_file_types,allowed_compression_types", [(["json", "pdf"], ["gz"])])
|
||||
class TestPayload:
|
||||
@pytest.mark.parametrize("x_tenant_id", [None, "klaus"])
|
||||
def test_payload_is_parsed_correctly(self, payload_parser, payload, expected_parsed_payload):
|
||||
payload = payload_parser(payload)
|
||||
assert payload == expected_parsed_payload
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"extension,expected",
|
||||
[
|
||||
("json.gz", ("json", "gz")),
|
||||
("json", ("json", None)),
|
||||
("prefix.json.gz", ("json", "gz")),
|
||||
],
|
||||
)
|
||||
def test_parse_file_extension(self, file_extension_parser, extension, expected):
|
||||
assert file_extension_parser(extension) == expected
|
||||
@ -8,14 +8,6 @@ import requests
|
||||
from pyinfra.payload_processing.processor import make_payload_processor
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def file_processor_mock():
|
||||
def inner(json_file: dict):
|
||||
return [json_file]
|
||||
|
||||
return inner
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def target_file():
|
||||
contents = {"numberOfPages": 10, "content1": "value1", "content2": "value2"}
|
||||
@ -23,54 +15,60 @@ def target_file():
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def file_names(request_payload):
|
||||
def file_names(payload):
|
||||
dossier_id, file_id, target_suffix, response_suffix = itemgetter(
|
||||
"dossierId",
|
||||
"fileId",
|
||||
"targetFileExtension",
|
||||
"responseFileExtension",
|
||||
)(request_payload)
|
||||
)(payload)
|
||||
return f"{dossier_id}/{file_id}.{target_suffix}", f"{dossier_id}/{file_id}.{response_suffix}"
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def payload_processor(file_processor_mock, processing_config):
|
||||
yield make_payload_processor(file_processor_mock, processing_config)
|
||||
def payload_processor(test_storage_config):
|
||||
def file_processor_mock(json_file: dict):
|
||||
return [json_file]
|
||||
|
||||
yield make_payload_processor(file_processor_mock, test_storage_config)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("client_name", ["s3"], scope="session")
|
||||
@pytest.mark.parametrize("storage_backend", ["s3"], scope="session")
|
||||
@pytest.mark.parametrize("bucket_name", ["testbucket"], scope="session")
|
||||
@pytest.mark.parametrize("monitoring_enabled", [True, False], scope="session")
|
||||
@pytest.mark.parametrize("x_tenant_id", [None])
|
||||
class TestPayloadProcessor:
|
||||
def test_payload_processor_yields_correct_response_and_uploads_result(
|
||||
self,
|
||||
payload_processor,
|
||||
storage,
|
||||
bucket_name,
|
||||
request_payload,
|
||||
payload,
|
||||
response_payload,
|
||||
target_file,
|
||||
file_names,
|
||||
):
|
||||
storage.clear_bucket(bucket_name)
|
||||
storage.put_object(bucket_name, file_names[0], target_file)
|
||||
response = payload_processor(request_payload)
|
||||
response = payload_processor(payload)
|
||||
|
||||
assert response == response_payload
|
||||
|
||||
data_received = storage.get_object(bucket_name, file_names[1])
|
||||
|
||||
assert json.loads((gzip.decompress(data_received)).decode("utf-8")) == {
|
||||
**request_payload,
|
||||
**payload,
|
||||
"data": [json.loads(gzip.decompress(target_file).decode("utf-8"))],
|
||||
}
|
||||
|
||||
def test_catching_of_processing_failure(self, payload_processor, storage, bucket_name, request_payload):
|
||||
def test_catching_of_processing_failure(self, payload_processor, storage, bucket_name, payload):
|
||||
storage.clear_bucket(bucket_name)
|
||||
with pytest.raises(Exception):
|
||||
payload_processor(request_payload)
|
||||
payload_processor(payload)
|
||||
|
||||
def test_prometheus_endpoint_is_available(self, processing_config):
|
||||
resp = requests.get(
|
||||
f"http://{processing_config.prometheus_host}:{processing_config.prometheus_port}/prometheus"
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
def test_prometheus_endpoint_is_available(self, test_storage_config, monitoring_enabled, storage_backend, x_tenant_id):
|
||||
if monitoring_enabled:
|
||||
resp = requests.get(
|
||||
f"http://{test_storage_config.prometheus_host}:{test_storage_config.prometheus_port}/prometheus"
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
@ -1,44 +0,0 @@
|
||||
import pytest
|
||||
|
||||
from pyinfra.config import get_config
|
||||
from pyinfra.payload_processing.payload import (
|
||||
QueueMessagePayload,
|
||||
get_queue_message_payload_parser,
|
||||
)
|
||||
from pyinfra.utils.file_extension_parsing import make_file_extension_parser
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def payload_config():
|
||||
return get_config()
|
||||
|
||||
|
||||
class TestPayload:
|
||||
def test_payload_is_parsed_correctly(self, request_payload, payload_config):
|
||||
parse_payload = get_queue_message_payload_parser(payload_config)
|
||||
payload = parse_payload(request_payload)
|
||||
assert payload == QueueMessagePayload(
|
||||
dossier_id="test",
|
||||
file_id="test",
|
||||
target_file_extension="json.gz",
|
||||
response_file_extension="json.gz",
|
||||
target_file_type="json",
|
||||
target_compression_type="gz",
|
||||
response_file_type="json",
|
||||
response_compression_type="gz",
|
||||
target_file_name="test/test.json.gz",
|
||||
response_file_name="test/test.json.gz",
|
||||
)
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"extension,expected",
|
||||
[
|
||||
("json.gz", ("json", "gz")),
|
||||
("json", ("json", None)),
|
||||
("prefix.json.gz", ("json", "gz")),
|
||||
],
|
||||
)
|
||||
@pytest.mark.parametrize("allowed_file_types,allowed_compression_types", [(["json", "pdf"], ["gz"])])
|
||||
def test_parse_file_extension(self, extension, expected, allowed_file_types, allowed_compression_types):
|
||||
parse = make_file_extension_parser(allowed_file_types, allowed_compression_types)
|
||||
assert parse(extension) == expected
|
||||
@ -8,15 +8,16 @@ import pika.exceptions
|
||||
import pytest
|
||||
|
||||
from pyinfra.queue.development_queue_manager import DevelopmentQueueManager
|
||||
from pyinfra.queue.queue_manager import QueueManager
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
logger.setLevel(logging.DEBUG)
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def development_queue_manager(queue_config):
|
||||
queue_config.rabbitmq_heartbeat = 7200
|
||||
development_queue_manager = DevelopmentQueueManager(queue_config)
|
||||
def development_queue_manager(test_queue_config):
|
||||
test_queue_config.rabbitmq_heartbeat = 7200
|
||||
development_queue_manager = DevelopmentQueueManager(test_queue_config)
|
||||
yield development_queue_manager
|
||||
logger.info("Tearing down development queue manager...")
|
||||
try:
|
||||
@ -26,10 +27,10 @@ def development_queue_manager(queue_config):
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def payload_processing_time(queue_config, offset=5):
|
||||
def payload_processing_time(test_queue_config, offset=5):
|
||||
# FIXME: this implicitly tests the heartbeat when running the end-to-end test. There should be another way to test
|
||||
# this explicitly.
|
||||
return queue_config.rabbitmq_heartbeat + offset
|
||||
return test_queue_config.rabbitmq_heartbeat + offset
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
@ -48,10 +49,11 @@ def payload_processor(response_payload, payload_processing_time, payload_process
|
||||
|
||||
|
||||
@pytest.fixture(scope="session", autouse=True)
|
||||
def start_queue_consumer(queue_manager, payload_processor, sleep_seconds=5):
|
||||
def start_queue_consumer(test_queue_config, payload_processor, sleep_seconds=5):
|
||||
def consume_queue():
|
||||
queue_manager.start_consuming(payload_processor)
|
||||
|
||||
queue_manager = QueueManager(test_queue_config)
|
||||
p = Process(target=consume_queue)
|
||||
p.start()
|
||||
logger.info(f"Setting up consumer, waiting for {sleep_seconds}...")
|
||||
@ -65,39 +67,40 @@ def start_queue_consumer(queue_manager, payload_processor, sleep_seconds=5):
|
||||
def message_properties(message_headers):
|
||||
if not message_headers:
|
||||
return pika.BasicProperties(headers=None)
|
||||
elif message_headers == "x-tenant-id":
|
||||
return pika.BasicProperties(headers={"x-tenant-id": "redaction"})
|
||||
elif message_headers == "X-TENANT-ID":
|
||||
return pika.BasicProperties(headers={"X-TENANT-ID": "redaction"})
|
||||
else:
|
||||
raise Exception(f"Invalid {message_headers=}.")
|
||||
|
||||
|
||||
@pytest.mark.parametrize("x_tenant_id", [None])
|
||||
class TestQueueManager:
|
||||
# FIXME: All tests here are wonky. This is due to the implementation of running the process-blocking queue_manager
|
||||
# in a subprocess. It is then very hard to interact directly with the subprocess. If you have a better idea, please
|
||||
# refactor; the tests here are insufficient to ensure the functionality of the queue manager!
|
||||
@pytest.mark.parametrize("payload_processor_type", ["mock"], scope="session")
|
||||
def test_message_processing_does_not_block_heartbeat(
|
||||
self, development_queue_manager, request_payload, response_payload, payload_processing_time
|
||||
self, development_queue_manager, payload, response_payload, payload_processing_time
|
||||
):
|
||||
development_queue_manager.clear_queues()
|
||||
development_queue_manager.publish_request(request_payload)
|
||||
development_queue_manager.publish_request(payload)
|
||||
time.sleep(payload_processing_time + 10)
|
||||
_, _, body = development_queue_manager.get_response()
|
||||
result = json.loads(body)
|
||||
assert result == response_payload
|
||||
|
||||
@pytest.mark.parametrize("message_headers", [None, "x-tenant-id"])
|
||||
@pytest.mark.parametrize("message_headers", [None, "X-TENANT-ID"])
|
||||
@pytest.mark.parametrize("payload_processor_type", ["mock"], scope="session")
|
||||
def test_queue_manager_forwards_message_headers(
|
||||
self,
|
||||
development_queue_manager,
|
||||
request_payload,
|
||||
payload,
|
||||
response_payload,
|
||||
payload_processing_time,
|
||||
message_properties,
|
||||
):
|
||||
development_queue_manager.clear_queues()
|
||||
development_queue_manager.publish_request(request_payload, message_properties)
|
||||
development_queue_manager.publish_request(payload, message_properties)
|
||||
time.sleep(payload_processing_time + 10)
|
||||
_, properties, _ = development_queue_manager.get_response()
|
||||
assert properties.headers == message_properties.headers
|
||||
@ -109,12 +112,12 @@ class TestQueueManager:
|
||||
def test_failed_message_processing_is_handled(
|
||||
self,
|
||||
development_queue_manager,
|
||||
request_payload,
|
||||
payload,
|
||||
response_payload,
|
||||
payload_processing_time,
|
||||
):
|
||||
development_queue_manager.clear_queues()
|
||||
development_queue_manager.publish_request(request_payload)
|
||||
development_queue_manager.publish_request(payload)
|
||||
time.sleep(payload_processing_time + 10)
|
||||
_, _, body = development_queue_manager.get_response()
|
||||
assert not body
|
||||
@ -6,7 +6,9 @@ logger = logging.getLogger(__name__)
|
||||
logger.setLevel(logging.DEBUG)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("client_name", ["azure", "s3"], scope="session")
|
||||
@pytest.mark.parametrize("storage_backend", ["azure", "s3"], scope="session")
|
||||
@pytest.mark.parametrize("bucket_name", ["testbucket"], scope="session")
|
||||
@pytest.mark.parametrize("monitoring_enabled", [False], scope="session")
|
||||
class TestStorage:
|
||||
def test_clearing_bucket_yields_empty_bucket(self, storage, bucket_name):
|
||||
storage.clear_bucket(bucket_name)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user