add multi-tenant storage connection 1st iteration
- forward x-tenant-id from queue message header to payload processor - add functions to receive storage infos from an endpoint or the config. This enables hashing and caching of connections created from these infos - add function to initialize storage connections from storage infos - streamline and refactor tests to make them more readable and robust and to make it easier to add new tests - update payload processor with first iteration of multi tenancy storage connection support with connection caching and backwards compability
This commit is contained in:
parent
52c047c47b
commit
d48e8108fd
@ -28,11 +28,9 @@ class Config:
|
|||||||
"PROMETHEUS_METRIC_PREFIX", "redactmanager_research_service_parameter"
|
"PROMETHEUS_METRIC_PREFIX", "redactmanager_research_service_parameter"
|
||||||
)
|
)
|
||||||
|
|
||||||
# Prometheus webserver address
|
# Prometheus webserver address and port
|
||||||
self.prometheus_host = read_from_environment("PROMETHEUS_HOST", "127.0.0.1")
|
self.prometheus_host = "0.0.0.0"
|
||||||
|
self.prometheus_port = 8080
|
||||||
# Prometheus webserver port
|
|
||||||
self.prometheus_port = int(read_from_environment("PROMETHEUS_PORT", 8080))
|
|
||||||
|
|
||||||
# RabbitMQ host address
|
# RabbitMQ host address
|
||||||
self.rabbitmq_host = read_from_environment("RABBITMQ_HOST", "localhost")
|
self.rabbitmq_host = read_from_environment("RABBITMQ_HOST", "localhost")
|
||||||
@ -94,6 +92,10 @@ class Config:
|
|||||||
self.allowed_file_types = ["json", "pdf"]
|
self.allowed_file_types = ["json", "pdf"]
|
||||||
self.allowed_compression_types = ["gz"]
|
self.allowed_compression_types = ["gz"]
|
||||||
|
|
||||||
|
# config for x-tenant-endpoint to receive storage connection information per tenant
|
||||||
|
self.persistence_service_public_key = "redaction"
|
||||||
|
self.persistence_service_tenant_endpoint = "http://persistence-service-v1:8080/internal-api/tenants"
|
||||||
|
|
||||||
# Value to see if we should write a consumer token to a file
|
# Value to see if we should write a consumer token to a file
|
||||||
self.write_consumer_token = read_from_environment("WRITE_CONSUMER_TOKEN", "False")
|
self.write_consumer_token = read_from_environment("WRITE_CONSUMER_TOKEN", "False")
|
||||||
|
|
||||||
|
|||||||
@ -1,2 +1,5 @@
|
|||||||
class ProcessingFailure(RuntimeError):
|
class ProcessingFailure(RuntimeError):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
class UnknownStorageBackend(Exception):
|
||||||
|
pass
|
||||||
@ -12,7 +12,7 @@ logger = logging.getLogger()
|
|||||||
|
|
||||||
|
|
||||||
class PrometheusMonitor:
|
class PrometheusMonitor:
|
||||||
def __init__(self, prefix: str, port=8080, host="127.0.0.1"):
|
def __init__(self, prefix: str, host: str, port: int):
|
||||||
"""Register the monitoring metrics and start a webserver where they can be scraped at the endpoint
|
"""Register the monitoring metrics and start a webserver where they can be scraped at the endpoint
|
||||||
http://{host}:{port}/prometheus
|
http://{host}:{port}/prometheus
|
||||||
|
|
||||||
@ -23,12 +23,9 @@ class PrometheusMonitor:
|
|||||||
self.registry = CollectorRegistry()
|
self.registry = CollectorRegistry()
|
||||||
|
|
||||||
self.entity_processing_time_sum = Summary(
|
self.entity_processing_time_sum = Summary(
|
||||||
f"{prefix}_processing_time",
|
f"{prefix}_processing_time", "Summed up average processing time per entity observed", registry=self.registry
|
||||||
"Summed up average processing time per entity observed",
|
|
||||||
)
|
)
|
||||||
|
|
||||||
self.registry.register(self.entity_processing_time_sum)
|
|
||||||
|
|
||||||
start_http_server(port, host, self.registry)
|
start_http_server(port, host, self.registry)
|
||||||
|
|
||||||
def __call__(self, process_fn: Callable) -> Callable:
|
def __call__(self, process_fn: Callable) -> Callable:
|
||||||
@ -58,8 +55,8 @@ class PrometheusMonitor:
|
|||||||
return inner
|
return inner
|
||||||
|
|
||||||
|
|
||||||
def get_monitor(config: Config) -> Callable:
|
def get_monitor_from_config(config: Config) -> Callable:
|
||||||
if config.monitoring_enabled:
|
if config.monitoring_enabled:
|
||||||
return PrometheusMonitor(*attrgetter("prometheus_metric_prefix", "prometheus_port", "prometheus_host")(config))
|
return PrometheusMonitor(*attrgetter("prometheus_metric_prefix", "prometheus_host", "prometheus_port")(config))
|
||||||
else:
|
else:
|
||||||
return identity
|
return identity
|
||||||
|
|||||||
@ -11,6 +11,8 @@ from pyinfra.utils.file_extension_parsing import make_file_extension_parser
|
|||||||
class QueueMessagePayload:
|
class QueueMessagePayload:
|
||||||
dossier_id: str
|
dossier_id: str
|
||||||
file_id: str
|
file_id: str
|
||||||
|
x_tenant_id: Union[str, None]
|
||||||
|
|
||||||
target_file_extension: str
|
target_file_extension: str
|
||||||
response_file_extension: str
|
response_file_extension: str
|
||||||
|
|
||||||
@ -35,6 +37,7 @@ class QueueMessagePayloadParser:
|
|||||||
dossier_id, file_id, target_file_extension, response_file_extension = itemgetter(
|
dossier_id, file_id, target_file_extension, response_file_extension = itemgetter(
|
||||||
"dossierId", "fileId", "targetFileExtension", "responseFileExtension"
|
"dossierId", "fileId", "targetFileExtension", "responseFileExtension"
|
||||||
)(payload)
|
)(payload)
|
||||||
|
x_tenant_id = payload.get("X-TENANT-ID")
|
||||||
|
|
||||||
target_file_type, target_compression_type, response_file_type, response_compression_type = chain.from_iterable(
|
target_file_type, target_compression_type, response_file_type, response_compression_type = chain.from_iterable(
|
||||||
map(self.parse_file_extensions, [target_file_extension, response_file_extension])
|
map(self.parse_file_extensions, [target_file_extension, response_file_extension])
|
||||||
@ -46,6 +49,7 @@ class QueueMessagePayloadParser:
|
|||||||
return QueueMessagePayload(
|
return QueueMessagePayload(
|
||||||
dossier_id=dossier_id,
|
dossier_id=dossier_id,
|
||||||
file_id=file_id,
|
file_id=file_id,
|
||||||
|
x_tenant_id=x_tenant_id,
|
||||||
target_file_extension=target_file_extension,
|
target_file_extension=target_file_extension,
|
||||||
response_file_extension=response_file_extension,
|
response_file_extension=response_file_extension,
|
||||||
target_file_type=target_file_type,
|
target_file_type=target_file_type,
|
||||||
|
|||||||
@ -6,16 +6,20 @@ from typing import Callable, Union, List
|
|||||||
from funcy import compose
|
from funcy import compose
|
||||||
|
|
||||||
from pyinfra.config import get_config, Config
|
from pyinfra.config import get_config, Config
|
||||||
from pyinfra.payload_processing.monitor import get_monitor
|
from pyinfra.payload_processing.monitor import get_monitor_from_config
|
||||||
from pyinfra.payload_processing.payload import (
|
from pyinfra.payload_processing.payload import (
|
||||||
QueueMessagePayloadParser,
|
QueueMessagePayloadParser,
|
||||||
get_queue_message_payload_parser,
|
get_queue_message_payload_parser,
|
||||||
QueueMessagePayloadFormatter,
|
QueueMessagePayloadFormatter,
|
||||||
get_queue_message_payload_formatter,
|
get_queue_message_payload_formatter,
|
||||||
)
|
)
|
||||||
from pyinfra.storage import get_storage
|
from pyinfra.storage.storage import make_downloader, make_uploader, get_storage_from_storage_info
|
||||||
from pyinfra.storage.storage import make_downloader, make_uploader
|
from pyinfra.storage.storage_info import (
|
||||||
from pyinfra.storage.storages.interface import Storage
|
AzureStorageInfo,
|
||||||
|
S3StorageInfo,
|
||||||
|
get_storage_info_from_config,
|
||||||
|
get_storage_info_from_endpoint,
|
||||||
|
)
|
||||||
|
|
||||||
logger = logging.getLogger()
|
logger = logging.getLogger()
|
||||||
logger.setLevel(get_config().logging_level_root)
|
logger.setLevel(get_config().logging_level_root)
|
||||||
@ -24,8 +28,8 @@ logger.setLevel(get_config().logging_level_root)
|
|||||||
class PayloadProcessor:
|
class PayloadProcessor:
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
storage: Storage,
|
default_storage_info: Union[AzureStorageInfo, S3StorageInfo],
|
||||||
bucket: str,
|
get_storage_info_from_tenant_id,
|
||||||
payload_parser: QueueMessagePayloadParser,
|
payload_parser: QueueMessagePayloadParser,
|
||||||
payload_formatter: QueueMessagePayloadFormatter,
|
payload_formatter: QueueMessagePayloadFormatter,
|
||||||
data_processor: Callable,
|
data_processor: Callable,
|
||||||
@ -33,8 +37,9 @@ class PayloadProcessor:
|
|||||||
"""Wraps an analysis function specified by a service (e.g. NER service) in pre- and post-processing steps.
|
"""Wraps an analysis function specified by a service (e.g. NER service) in pre- and post-processing steps.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
storage: The storage to use for downloading and uploading files
|
default_storage_info: The default storage info used to create the storage connection. This is only used if
|
||||||
bucket: The bucket to use for downloading and uploading files
|
x_tenant_id is not provided in the queue payload.
|
||||||
|
get_storage_info_from_tenant_id: Callable to acquire storage info from a given tenant id.
|
||||||
payload_parser: Parser that translates the queue message payload to the required QueueMessagePayload object
|
payload_parser: Parser that translates the queue message payload to the required QueueMessagePayload object
|
||||||
payload_formatter: Formatter for the storage upload result and the queue message response body
|
payload_formatter: Formatter for the storage upload result and the queue message response body
|
||||||
data_processor: The analysis function to be called with the downloaded file
|
data_processor: The analysis function to be called with the downloaded file
|
||||||
@ -46,8 +51,10 @@ class PayloadProcessor:
|
|||||||
self.format_to_queue_message_response_body = payload_formatter.format_to_queue_message_response_body
|
self.format_to_queue_message_response_body = payload_formatter.format_to_queue_message_response_body
|
||||||
self.process_data = data_processor
|
self.process_data = data_processor
|
||||||
|
|
||||||
self.make_downloader = partial(make_downloader, storage, bucket)
|
self.get_storage_info_from_tenant_id = get_storage_info_from_tenant_id
|
||||||
self.make_uploader = partial(make_uploader, storage, bucket)
|
self.default_storage_info = default_storage_info
|
||||||
|
# TODO: use lru-dict
|
||||||
|
self.storages = {}
|
||||||
|
|
||||||
def __call__(self, queue_message_payload: dict) -> dict:
|
def __call__(self, queue_message_payload: dict) -> dict:
|
||||||
"""Processes a queue message payload.
|
"""Processes a queue message payload.
|
||||||
@ -71,8 +78,16 @@ class PayloadProcessor:
|
|||||||
payload = self.parse_payload(queue_message_payload)
|
payload = self.parse_payload(queue_message_payload)
|
||||||
logger.info(f"Processing {asdict(payload)} ...")
|
logger.info(f"Processing {asdict(payload)} ...")
|
||||||
|
|
||||||
download_file_to_process = self.make_downloader(payload.target_file_type, payload.target_compression_type)
|
storage_info = self._get_storage_info(payload.x_tenant_id)
|
||||||
upload_processing_result = self.make_uploader(payload.response_file_type, payload.response_compression_type)
|
bucket = storage_info.bucket_name
|
||||||
|
storage = self._get_storage(storage_info)
|
||||||
|
|
||||||
|
download_file_to_process = make_downloader(
|
||||||
|
storage, bucket, payload.target_file_type, payload.target_compression_type
|
||||||
|
)
|
||||||
|
upload_processing_result = make_uploader(
|
||||||
|
storage, bucket, payload.response_file_type, payload.response_compression_type
|
||||||
|
)
|
||||||
format_result_for_storage = partial(self.format_result_for_storage, payload)
|
format_result_for_storage = partial(self.format_result_for_storage, payload)
|
||||||
|
|
||||||
processing_pipeline = compose(format_result_for_storage, self.process_data, download_file_to_process)
|
processing_pipeline = compose(format_result_for_storage, self.process_data, download_file_to_process)
|
||||||
@ -83,17 +98,40 @@ class PayloadProcessor:
|
|||||||
|
|
||||||
return self.format_to_queue_message_response_body(payload)
|
return self.format_to_queue_message_response_body(payload)
|
||||||
|
|
||||||
|
def _get_storage_info(self, x_tenant_id=None):
|
||||||
|
if x_tenant_id:
|
||||||
|
return self.get_storage_info_from_tenant_id(x_tenant_id)
|
||||||
|
return self.default_storage_info
|
||||||
|
|
||||||
|
def _get_storage(self, storage_info):
|
||||||
|
if storage_info in self.storages:
|
||||||
|
return self.storages[storage_info]
|
||||||
|
else:
|
||||||
|
storage = get_storage_from_storage_info(storage_info)
|
||||||
|
self.storages[storage_info] = storage
|
||||||
|
return storage
|
||||||
|
|
||||||
|
|
||||||
def make_payload_processor(data_processor: Callable, config: Union[None, Config] = None) -> PayloadProcessor:
|
def make_payload_processor(data_processor: Callable, config: Union[None, Config] = None) -> PayloadProcessor:
|
||||||
"""Produces payload processor for queue manager."""
|
"""Produces payload processor for queue manager."""
|
||||||
config = config or get_config()
|
config = config or get_config()
|
||||||
|
|
||||||
bucket: str = config.storage_bucket
|
default_storage_info: Union[AzureStorageInfo, S3StorageInfo] = get_storage_info_from_config(config)
|
||||||
storage: Storage = get_storage(config)
|
get_storage_info_from_tenant_id = partial(
|
||||||
monitor = get_monitor(config)
|
get_storage_info_from_endpoint,
|
||||||
|
config.persistence_service_public_key,
|
||||||
|
config.persistence_service_tenant_endpoint,
|
||||||
|
)
|
||||||
|
monitor = get_monitor_from_config(config)
|
||||||
payload_parser: QueueMessagePayloadParser = get_queue_message_payload_parser(config)
|
payload_parser: QueueMessagePayloadParser = get_queue_message_payload_parser(config)
|
||||||
payload_formatter: QueueMessagePayloadFormatter = get_queue_message_payload_formatter()
|
payload_formatter: QueueMessagePayloadFormatter = get_queue_message_payload_formatter()
|
||||||
|
|
||||||
data_processor = monitor(data_processor)
|
data_processor = monitor(data_processor)
|
||||||
|
|
||||||
return PayloadProcessor(storage, bucket, payload_parser, payload_formatter, data_processor)
|
return PayloadProcessor(
|
||||||
|
default_storage_info,
|
||||||
|
get_storage_info_from_tenant_id,
|
||||||
|
payload_parser,
|
||||||
|
payload_formatter,
|
||||||
|
data_processor,
|
||||||
|
)
|
||||||
|
|||||||
@ -12,6 +12,7 @@ from pika.adapters.blocking_connection import BlockingChannel
|
|||||||
from pyinfra.config import Config
|
from pyinfra.config import Config
|
||||||
from pyinfra.exception import ProcessingFailure
|
from pyinfra.exception import ProcessingFailure
|
||||||
from pyinfra.payload_processing.processor import PayloadProcessor
|
from pyinfra.payload_processing.processor import PayloadProcessor
|
||||||
|
from pyinfra.utils.dict import save_project
|
||||||
|
|
||||||
CONFIG = Config()
|
CONFIG = Config()
|
||||||
|
|
||||||
@ -164,8 +165,8 @@ class QueueManager:
|
|||||||
except Exception as err:
|
except Exception as err:
|
||||||
raise ProcessingFailure("QueueMessagePayload processing failed") from err
|
raise ProcessingFailure("QueueMessagePayload processing failed") from err
|
||||||
|
|
||||||
def acknowledge_message_and_publish_response(frame, properties, response_body):
|
def acknowledge_message_and_publish_response(frame, headers, response_body):
|
||||||
response_properties = pika.BasicProperties(headers=properties.headers) if properties.headers else None
|
response_properties = pika.BasicProperties(headers=headers) if headers else None
|
||||||
self._channel.basic_publish("", self._output_queue, json.dumps(response_body).encode(), response_properties)
|
self._channel.basic_publish("", self._output_queue, json.dumps(response_body).encode(), response_properties)
|
||||||
self.logger.info(
|
self.logger.info(
|
||||||
"Result published, acknowledging incoming message with delivery_tag %s",
|
"Result published, acknowledging incoming message with delivery_tag %s",
|
||||||
@ -190,12 +191,15 @@ class QueueManager:
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
self.logger.debug("Processing (%s, %s, %s)", frame, properties, body)
|
self.logger.debug("Processing (%s, %s, %s)", frame, properties, body)
|
||||||
processing_result = process_message_body_and_await_result(json.loads(body))
|
filtered_message_headers = save_project(properties.headers, ["X-TENANT-ID"]) # TODO: parametrize key?
|
||||||
|
message_body = {**json.loads(body), **filtered_message_headers}
|
||||||
|
|
||||||
|
processing_result = process_message_body_and_await_result(message_body)
|
||||||
self.logger.info(
|
self.logger.info(
|
||||||
"Processed message with delivery_tag %s, publishing result to result-queue",
|
"Processed message with delivery_tag %s, publishing result to result-queue",
|
||||||
frame.delivery_tag,
|
frame.delivery_tag,
|
||||||
)
|
)
|
||||||
acknowledge_message_and_publish_response(frame, properties, processing_result)
|
acknowledge_message_and_publish_response(frame, filtered_message_headers, processing_result)
|
||||||
|
|
||||||
except ProcessingFailure:
|
except ProcessingFailure:
|
||||||
self.logger.info(
|
self.logger.info(
|
||||||
|
|||||||
@ -1,3 +1,3 @@
|
|||||||
from pyinfra.storage.storage import get_storage
|
from pyinfra.storage.storage import get_storage_from_config
|
||||||
|
|
||||||
__all__ = ["get_storage"]
|
__all__ = ["get_storage_from_config"]
|
||||||
|
|||||||
@ -1,28 +1,45 @@
|
|||||||
from functools import lru_cache, partial
|
from functools import lru_cache, partial
|
||||||
from typing import Callable
|
from typing import Callable, Union
|
||||||
|
|
||||||
|
from azure.storage.blob import BlobServiceClient
|
||||||
from funcy import compose
|
from funcy import compose
|
||||||
|
from minio import Minio
|
||||||
|
|
||||||
from pyinfra.config import Config
|
from pyinfra.config import Config
|
||||||
from pyinfra.storage.storages.azure import get_azure_storage
|
from pyinfra.exception import UnknownStorageBackend
|
||||||
from pyinfra.storage.storages.s3 import get_s3_storage
|
from pyinfra.storage.storage_info import AzureStorageInfo, S3StorageInfo, get_storage_info_from_config
|
||||||
|
from pyinfra.storage.storages.azure import AzureStorage
|
||||||
from pyinfra.storage.storages.interface import Storage
|
from pyinfra.storage.storages.interface import Storage
|
||||||
|
from pyinfra.storage.storages.s3 import S3Storage
|
||||||
from pyinfra.utils.compressing import get_decompressor, get_compressor
|
from pyinfra.utils.compressing import get_decompressor, get_compressor
|
||||||
from pyinfra.utils.encoding import get_decoder, get_encoder
|
from pyinfra.utils.encoding import get_decoder, get_encoder
|
||||||
|
|
||||||
|
|
||||||
def get_storage(config: Config) -> Storage:
|
def get_storage_from_config(config: Config) -> Storage:
|
||||||
|
|
||||||
if config.storage_backend == "s3":
|
storage_info = get_storage_info_from_config(config)
|
||||||
storage = get_s3_storage(config)
|
storage = get_storage_from_storage_info(storage_info)
|
||||||
elif config.storage_backend == "azure":
|
|
||||||
storage = get_azure_storage(config)
|
|
||||||
else:
|
|
||||||
raise Exception(f"Unknown storage backend '{config.storage_backend}'.")
|
|
||||||
|
|
||||||
return storage
|
return storage
|
||||||
|
|
||||||
|
|
||||||
|
def get_storage_from_storage_info(storage_info: Union[AzureStorageInfo, S3StorageInfo]) -> Storage:
|
||||||
|
if isinstance(storage_info, AzureStorageInfo):
|
||||||
|
return AzureStorage(BlobServiceClient.from_connection_string(conn_str=storage_info.connection_string))
|
||||||
|
elif isinstance(storage_info, S3StorageInfo):
|
||||||
|
return S3Storage(
|
||||||
|
Minio(
|
||||||
|
secure=storage_info.secure,
|
||||||
|
endpoint=storage_info.endpoint,
|
||||||
|
access_key=storage_info.access_key,
|
||||||
|
secret_key=storage_info.secret_key,
|
||||||
|
region=storage_info.region,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise UnknownStorageBackend()
|
||||||
|
|
||||||
|
|
||||||
def verify_existence(storage: Storage, bucket: str, file_name: str) -> str:
|
def verify_existence(storage: Storage, bucket: str, file_name: str) -> str:
|
||||||
if not storage.exists(bucket, file_name):
|
if not storage.exists(bucket, file_name):
|
||||||
raise FileNotFoundError(f"{file_name=} name not found on storage in {bucket=}.")
|
raise FileNotFoundError(f"{file_name=} name not found on storage in {bucket=}.")
|
||||||
|
|||||||
89
pyinfra/storage/storage_info.py
Normal file
89
pyinfra/storage/storage_info.py
Normal file
@ -0,0 +1,89 @@
|
|||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Union
|
||||||
|
|
||||||
|
import requests
|
||||||
|
|
||||||
|
from pyinfra.config import Config
|
||||||
|
from pyinfra.exception import UnknownStorageBackend
|
||||||
|
from pyinfra.utils.cipher import decrypt
|
||||||
|
from pyinfra.utils.url_parsing import validate_and_parse_s3_endpoint
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class AzureStorageInfo:
|
||||||
|
connection_string: str
|
||||||
|
bucket_name: str
|
||||||
|
|
||||||
|
def __hash__(self):
|
||||||
|
return hash(self.connection_string)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class S3StorageInfo:
|
||||||
|
secure: bool
|
||||||
|
endpoint: str
|
||||||
|
access_key: str
|
||||||
|
secret_key: str
|
||||||
|
region: str
|
||||||
|
bucket_name: str
|
||||||
|
|
||||||
|
def __hash__(self):
|
||||||
|
return hash((self.secure, self.endpoint, self.access_key, self.secret_key, self.region))
|
||||||
|
|
||||||
|
|
||||||
|
def get_storage_info_from_endpoint(
|
||||||
|
public_key: str, endpoint: str, x_tenant_id: str
|
||||||
|
) -> Union[AzureStorageInfo, S3StorageInfo]:
|
||||||
|
# FIXME: parameterize port, host and public_key
|
||||||
|
public_key = "redaction"
|
||||||
|
resp = requests.get(f"{endpoint}/{x_tenant_id}").json()
|
||||||
|
|
||||||
|
maybe_azure = resp.get("azureStorageConnection")
|
||||||
|
maybe_s3 = resp.get("azureStorageConnection")
|
||||||
|
assert not (maybe_azure and maybe_s3)
|
||||||
|
|
||||||
|
if maybe_azure:
|
||||||
|
connection_string = decrypt(public_key, maybe_azure["connectionString"])
|
||||||
|
storage_info = AzureStorageInfo(
|
||||||
|
connection_string=connection_string,
|
||||||
|
bucket_name=maybe_azure["containerName"],
|
||||||
|
)
|
||||||
|
elif maybe_s3:
|
||||||
|
secure, endpoint = validate_and_parse_s3_endpoint(maybe_s3["endpoint"])
|
||||||
|
secret = decrypt(public_key, maybe_s3["secret"])
|
||||||
|
|
||||||
|
storage_info = S3StorageInfo(
|
||||||
|
secure=secure,
|
||||||
|
endpoint=endpoint,
|
||||||
|
access_key=maybe_s3["key"],
|
||||||
|
secret_key=secret,
|
||||||
|
region=maybe_s3["region"],
|
||||||
|
bucket_name=maybe_s3,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise UnknownStorageBackend()
|
||||||
|
|
||||||
|
return storage_info
|
||||||
|
|
||||||
|
|
||||||
|
def get_storage_info_from_config(config: Config) -> Union[AzureStorageInfo, S3StorageInfo]:
|
||||||
|
if config.storage_backend == "s3":
|
||||||
|
storage_info = S3StorageInfo(
|
||||||
|
secure=config.storage_secure_connection,
|
||||||
|
endpoint=config.storage_endpoint,
|
||||||
|
access_key=config.storage_key,
|
||||||
|
secret_key=config.storage_secret,
|
||||||
|
region=config.storage_region,
|
||||||
|
bucket_name=config.storage_bucket,
|
||||||
|
)
|
||||||
|
|
||||||
|
elif config.storage_backend == "azure":
|
||||||
|
storage_info = AzureStorageInfo(
|
||||||
|
connection_string=config.storage_azureconnectionstring,
|
||||||
|
bucket_name=config.storage_bucket,
|
||||||
|
)
|
||||||
|
|
||||||
|
else:
|
||||||
|
raise UnknownStorageBackend(f"Unknown storage backend '{config.storage_backend}'.")
|
||||||
|
|
||||||
|
return storage_info
|
||||||
@ -77,5 +77,5 @@ class AzureStorage(Storage):
|
|||||||
return zip(repeat(bucket_name), map(attrgetter("name"), blobs))
|
return zip(repeat(bucket_name), map(attrgetter("name"), blobs))
|
||||||
|
|
||||||
|
|
||||||
def get_azure_storage(config: Config):
|
def get_azure_storage_from_config(config: Config):
|
||||||
return AzureStorage(BlobServiceClient.from_connection_string(conn_str=config.storage_azureconnectionstring))
|
return AzureStorage(BlobServiceClient.from_connection_string(conn_str=config.storage_azureconnectionstring))
|
||||||
|
|||||||
@ -67,7 +67,7 @@ class S3Storage(Storage):
|
|||||||
return zip(repeat(bucket_name), map(attrgetter("object_name"), objs))
|
return zip(repeat(bucket_name), map(attrgetter("object_name"), objs))
|
||||||
|
|
||||||
|
|
||||||
def get_s3_storage(config: Config):
|
def get_s3_storage_from_config(config: Config):
|
||||||
return S3Storage(
|
return S3Storage(
|
||||||
Minio(
|
Minio(
|
||||||
secure=config.storage_secure_connection,
|
secure=config.storage_secure_connection,
|
||||||
|
|||||||
5
pyinfra/utils/dict.py
Normal file
5
pyinfra/utils/dict.py
Normal file
@ -0,0 +1,5 @@
|
|||||||
|
from funcy import project
|
||||||
|
|
||||||
|
|
||||||
|
def save_project(mapping, keys) -> dict:
|
||||||
|
return project(mapping, keys) if mapping else {}
|
||||||
@ -7,7 +7,7 @@ import pika
|
|||||||
|
|
||||||
from pyinfra.config import get_config
|
from pyinfra.config import get_config
|
||||||
from pyinfra.queue.development_queue_manager import DevelopmentQueueManager
|
from pyinfra.queue.development_queue_manager import DevelopmentQueueManager
|
||||||
from pyinfra.storage.storages.s3 import get_s3_storage
|
from pyinfra.storage.storages.s3 import get_s3_storage_from_config
|
||||||
|
|
||||||
CONFIG = get_config()
|
CONFIG = get_config()
|
||||||
logging.basicConfig()
|
logging.basicConfig()
|
||||||
@ -26,7 +26,7 @@ def upload_json_and_make_message_body():
|
|||||||
object_name = f"{dossier_id}/{file_id}.{suffix}"
|
object_name = f"{dossier_id}/{file_id}.{suffix}"
|
||||||
data = gzip.compress(json.dumps(content).encode("utf-8"))
|
data = gzip.compress(json.dumps(content).encode("utf-8"))
|
||||||
|
|
||||||
storage = get_s3_storage(CONFIG)
|
storage = get_s3_storage_from_config(CONFIG)
|
||||||
if not storage.has_bucket(bucket):
|
if not storage.has_bucket(bucket):
|
||||||
storage.make_bucket(bucket)
|
storage.make_bucket(bucket)
|
||||||
storage.put_object(bucket, object_name, data)
|
storage.put_object(bucket, object_name, data)
|
||||||
@ -46,10 +46,10 @@ def main():
|
|||||||
|
|
||||||
message = upload_json_and_make_message_body()
|
message = upload_json_and_make_message_body()
|
||||||
|
|
||||||
development_queue_manager.publish_request(message, pika.BasicProperties(headers={"x-tenant-id": "redaction"}))
|
development_queue_manager.publish_request(message, pika.BasicProperties(headers={"X-TENANT-ID": "redaction"}))
|
||||||
logger.info(f"Put {message} on {CONFIG.request_queue}")
|
logger.info(f"Put {message} on {CONFIG.request_queue}")
|
||||||
|
|
||||||
storage = get_s3_storage(CONFIG)
|
storage = get_s3_storage_from_config(CONFIG)
|
||||||
for method_frame, properties, body in development_queue_manager._channel.consume(
|
for method_frame, properties, body in development_queue_manager._channel.consume(
|
||||||
queue=CONFIG.response_queue, inactivity_timeout=15
|
queue=CONFIG.response_queue, inactivity_timeout=15
|
||||||
):
|
):
|
||||||
|
|||||||
194
test.ipynb
194
test.ipynb
@ -1,194 +0,0 @@
|
|||||||
{
|
|
||||||
"cells": [
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": 10,
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [
|
|
||||||
{
|
|
||||||
"ename": "ModuleNotFoundError",
|
|
||||||
"evalue": "No module named 'pprint.pprint'; 'pprint' is not a package",
|
|
||||||
"output_type": "error",
|
|
||||||
"traceback": [
|
|
||||||
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
|
|
||||||
"\u001b[0;31mModuleNotFoundError\u001b[0m Traceback (most recent call last)",
|
|
||||||
"Cell \u001b[0;32mIn [10], line 4\u001b[0m\n\u001b[1;32m 2\u001b[0m \u001b[38;5;28;01mimport\u001b[39;00m \u001b[38;5;21;01myaml\u001b[39;00m\n\u001b[1;32m 3\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01myaml\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mloader\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m FullLoader\n\u001b[0;32m----> 4\u001b[0m \u001b[38;5;28;01mimport\u001b[39;00m \u001b[38;5;21;01mpprint\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mpprint\u001b[39;00m \u001b[38;5;28;01mas\u001b[39;00m \u001b[38;5;21;01mpp\u001b[39;00m\n",
|
|
||||||
"\u001b[0;31mModuleNotFoundError\u001b[0m: No module named 'pprint.pprint'; 'pprint' is not a package"
|
|
||||||
]
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"source": [
|
|
||||||
"import pyinfra\n",
|
|
||||||
"import yaml\n",
|
|
||||||
"from yaml.loader import FullLoader\n",
|
|
||||||
"import pprint"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": 12,
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [
|
|
||||||
{
|
|
||||||
"name": "stdout",
|
|
||||||
"output_type": "stream",
|
|
||||||
"text": [
|
|
||||||
"{'logging': 0,\n",
|
|
||||||
" 'mock_analysis_endpoint': 'http://127.0.0.1:5000',\n",
|
|
||||||
" 'service': {'operations': {'classify': {'input': {'extension': 'cls_in.gz',\n",
|
|
||||||
" 'multi': True,\n",
|
|
||||||
" 'subdir': ''},\n",
|
|
||||||
" 'output': {'extension': 'cls_out.gz',\n",
|
|
||||||
" 'subdir': ''}},\n",
|
|
||||||
" 'default': {'input': {'extension': 'IN.gz',\n",
|
|
||||||
" 'multi': False,\n",
|
|
||||||
" 'subdir': ''},\n",
|
|
||||||
" 'output': {'extension': 'OUT.gz',\n",
|
|
||||||
" 'subdir': ''}},\n",
|
|
||||||
" 'extract': {'input': {'extension': 'extr_in.gz',\n",
|
|
||||||
" 'multi': False,\n",
|
|
||||||
" 'subdir': ''},\n",
|
|
||||||
" 'output': {'extension': 'gz',\n",
|
|
||||||
" 'subdir': 'extractions'}},\n",
|
|
||||||
" 'rotate': {'input': {'extension': 'rot_in.gz',\n",
|
|
||||||
" 'multi': False,\n",
|
|
||||||
" 'subdir': ''},\n",
|
|
||||||
" 'output': {'extension': 'rot_out.gz',\n",
|
|
||||||
" 'subdir': ''}},\n",
|
|
||||||
" 'stream_pages': {'input': {'extension': 'pgs_in.gz',\n",
|
|
||||||
" 'multi': False,\n",
|
|
||||||
" 'subdir': ''},\n",
|
|
||||||
" 'output': {'extension': 'pgs_out.gz',\n",
|
|
||||||
" 'subdir': 'pages'}},\n",
|
|
||||||
" 'upper': {'input': {'extension': 'up_in.gz',\n",
|
|
||||||
" 'multi': False,\n",
|
|
||||||
" 'subdir': ''},\n",
|
|
||||||
" 'output': {'extension': 'up_out.gz',\n",
|
|
||||||
" 'subdir': ''}}},\n",
|
|
||||||
" 'response_formatter': 'identity'},\n",
|
|
||||||
" 'storage': {'aws': {'access_key': 'AKIA4QVP6D4LCDAGYGN2',\n",
|
|
||||||
" 'endpoint': 'https://s3.amazonaws.com',\n",
|
|
||||||
" 'region': '$STORAGE_REGION|\"eu-west-1\"',\n",
|
|
||||||
" 'secret_key': '8N6H1TUHTsbvW2qMAm7zZlJ63hMqjcXAsdN7TYED'},\n",
|
|
||||||
" 'azure': {'connection_string': 'DefaultEndpointsProtocol=https;AccountName=iqserdevelopment;AccountKey=4imAbV9PYXaztSOMpIyAClg88bAZCXuXMGJG0GA1eIBpdh2PlnFGoRBnKqLy2YZUSTmZ3wJfC7tzfHtuC6FEhQ==;EndpointSuffix=core.windows.net'},\n",
|
|
||||||
" 'bucket': 'pyinfra-test-bucket',\n",
|
|
||||||
" 'minio': {'access_key': 'root',\n",
|
|
||||||
" 'endpoint': 'http://127.0.0.1:9000',\n",
|
|
||||||
" 'region': None,\n",
|
|
||||||
" 'secret_key': 'password'}},\n",
|
|
||||||
" 'use_docker_fixture': 1,\n",
|
|
||||||
" 'webserver': {'host': '$SERVER_HOST|\"127.0.0.1\"',\n",
|
|
||||||
" 'mode': '$SERVER_MODE|production',\n",
|
|
||||||
" 'port': '$SERVER_PORT|5000'}}\n"
|
|
||||||
]
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"source": [
|
|
||||||
"\n",
|
|
||||||
"# Open the file and load the file\n",
|
|
||||||
"with open('./tests/config.yml') as f:\n",
|
|
||||||
" data = yaml.load(f, Loader=FullLoader)\n",
|
|
||||||
" pprint.pprint(data)"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": 22,
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
|
||||||
"import os\n",
|
|
||||||
"os.environ[\"STORAGE_BACKEND\"] = \"azure\"\n",
|
|
||||||
"\n",
|
|
||||||
"# always the same\n",
|
|
||||||
"os.environ[\"STORAGE_BUCKET_NAME\"] = \"pyinfra-test-bucket\"\n",
|
|
||||||
"\n",
|
|
||||||
"# s3\n",
|
|
||||||
"os.environ[\"STORAGE_ENDPOINT\"] = \"https://s3.amazonaws.com\"\n",
|
|
||||||
"os.environ[\"STORAGE_KEY\"] = \"AKIA4QVP6D4LCDAGYGN2\"\n",
|
|
||||||
"os.environ[\"STORAGE_SECRET\"] = \"8N6H1TUHTsbvW2qMAm7zZlJ63hMqjcXAsdN7TYED\"\n",
|
|
||||||
"os.environ[\"STORAGE_REGION\"] = \"eu-west-1\"\n",
|
|
||||||
"\n",
|
|
||||||
"# aks\n",
|
|
||||||
"os.environ[\"STORAGE_AZURECONNECTIONSTRING\"] = \"DefaultEndpointsProtocol=https;AccountName=iqserdevelopment;AccountKey=4imAbV9PYXaztSOMpIyAClg88bAZCXuXMGJG0GA1eIBpdh2PlnFGoRBnKqLy2YZUSTmZ3wJfC7tzfHtuC6FEhQ==;EndpointSuffix=core.windows.net\""
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": 23,
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [
|
|
||||||
{
|
|
||||||
"ename": "Exception",
|
|
||||||
"evalue": "Unknown storage backend 'aks'.",
|
|
||||||
"output_type": "error",
|
|
||||||
"traceback": [
|
|
||||||
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
|
|
||||||
"\u001b[0;31mException\u001b[0m Traceback (most recent call last)",
|
|
||||||
"Cell \u001b[0;32mIn [23], line 2\u001b[0m\n\u001b[1;32m 1\u001b[0m config \u001b[38;5;241m=\u001b[39m pyinfra\u001b[38;5;241m.\u001b[39mconfig\u001b[38;5;241m.\u001b[39mget_config()\n\u001b[0;32m----> 2\u001b[0m storage \u001b[38;5;241m=\u001b[39m \u001b[43mpyinfra\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mstorage\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mget_storage\u001b[49m\u001b[43m(\u001b[49m\u001b[43mconfig\u001b[49m\u001b[43m)\u001b[49m\n",
|
|
||||||
"File \u001b[0;32m~/dev/pyinfra/pyinfra/storage/storage.py:15\u001b[0m, in \u001b[0;36mget_storage\u001b[0;34m(config)\u001b[0m\n\u001b[1;32m 13\u001b[0m storage \u001b[39m=\u001b[39m get_azure_storage(config)\n\u001b[1;32m 14\u001b[0m \u001b[39melse\u001b[39;00m:\n\u001b[0;32m---> 15\u001b[0m \u001b[39mraise\u001b[39;00m \u001b[39mException\u001b[39;00m(\u001b[39mf\u001b[39m\u001b[39m\"\u001b[39m\u001b[39mUnknown storage backend \u001b[39m\u001b[39m'\u001b[39m\u001b[39m{\u001b[39;00mconfig\u001b[39m.\u001b[39mstorage_backend\u001b[39m}\u001b[39;00m\u001b[39m'\u001b[39m\u001b[39m.\u001b[39m\u001b[39m\"\u001b[39m)\n\u001b[1;32m 17\u001b[0m \u001b[39mreturn\u001b[39;00m storage\n",
|
|
||||||
"\u001b[0;31mException\u001b[0m: Unknown storage backend 'aks'."
|
|
||||||
]
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"source": [
|
|
||||||
"config = pyinfra.config.get_config()\n",
|
|
||||||
"storage = pyinfra.storage.get_storage(config)"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": 21,
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [
|
|
||||||
{
|
|
||||||
"data": {
|
|
||||||
"text/plain": [
|
|
||||||
"False"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
"execution_count": 21,
|
|
||||||
"metadata": {},
|
|
||||||
"output_type": "execute_result"
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"source": [
|
|
||||||
"storage.has_bucket(config.storage_bucket)"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": null,
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [],
|
|
||||||
"source": []
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"metadata": {
|
|
||||||
"kernelspec": {
|
|
||||||
"display_name": "Python 3.8.13 ('pyinfra-TboPpZ8z-py3.8')",
|
|
||||||
"language": "python",
|
|
||||||
"name": "python3"
|
|
||||||
},
|
|
||||||
"language_info": {
|
|
||||||
"codemirror_mode": {
|
|
||||||
"name": "ipython",
|
|
||||||
"version": 3
|
|
||||||
},
|
|
||||||
"file_extension": ".py",
|
|
||||||
"mimetype": "text/x-python",
|
|
||||||
"name": "python",
|
|
||||||
"nbconvert_exporter": "python",
|
|
||||||
"pygments_lexer": "ipython3",
|
|
||||||
"version": "3.8.13"
|
|
||||||
},
|
|
||||||
"orig_nbformat": 4,
|
|
||||||
"vscode": {
|
|
||||||
"interpreter": {
|
|
||||||
"hash": "10d7419af5ea6dfec0078ebc9d6fa1a9383fe9894853f90dc7d29a81b3de2c78"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"nbformat": 4,
|
|
||||||
"nbformat_minor": 2
|
|
||||||
}
|
|
||||||
@ -6,8 +6,7 @@ import pytest
|
|||||||
import testcontainers.compose
|
import testcontainers.compose
|
||||||
|
|
||||||
from pyinfra.config import get_config
|
from pyinfra.config import get_config
|
||||||
from pyinfra.queue.queue_manager import QueueManager
|
from pyinfra.storage import get_storage_from_config
|
||||||
from pyinfra.storage import get_storage
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
logger.setLevel(logging.DEBUG)
|
logger.setLevel(logging.DEBUG)
|
||||||
@ -30,60 +29,35 @@ def docker_compose(sleep_seconds=30):
|
|||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="session")
|
@pytest.fixture(scope="session")
|
||||||
def storage_config(client_name):
|
def test_storage_config(storage_backend, bucket_name, monitoring_enabled):
|
||||||
config = get_config()
|
config = get_config()
|
||||||
config.storage_backend = client_name
|
config.storage_backend = storage_backend
|
||||||
|
config.storage_bucket = bucket_name
|
||||||
config.storage_azureconnectionstring = "DefaultEndpointsProtocol=https;AccountName=iqserdevelopment;AccountKey=4imAbV9PYXaztSOMpIyAClg88bAZCXuXMGJG0GA1eIBpdh2PlnFGoRBnKqLy2YZUSTmZ3wJfC7tzfHtuC6FEhQ==;EndpointSuffix=core.windows.net"
|
config.storage_azureconnectionstring = "DefaultEndpointsProtocol=https;AccountName=iqserdevelopment;AccountKey=4imAbV9PYXaztSOMpIyAClg88bAZCXuXMGJG0GA1eIBpdh2PlnFGoRBnKqLy2YZUSTmZ3wJfC7tzfHtuC6FEhQ==;EndpointSuffix=core.windows.net"
|
||||||
|
config.monitoring_enabled = monitoring_enabled
|
||||||
|
config.prometheus_metric_prefix = "test"
|
||||||
|
config.prometheus_port = 8080
|
||||||
|
config.prometheus_host = "0.0.0.0"
|
||||||
return config
|
return config
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="session")
|
@pytest.fixture(scope="session")
|
||||||
def processing_config(storage_config, monitoring_enabled):
|
def test_queue_config():
|
||||||
storage_config.monitoring_enabled = monitoring_enabled
|
|
||||||
return storage_config
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="session")
|
|
||||||
def bucket_name(storage_config):
|
|
||||||
return storage_config.storage_bucket
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="session")
|
|
||||||
def storage(storage_config):
|
|
||||||
logger.debug("Setup for storage")
|
|
||||||
storage = get_storage(storage_config)
|
|
||||||
storage.make_bucket(storage_config.storage_bucket)
|
|
||||||
storage.clear_bucket(storage_config.storage_bucket)
|
|
||||||
yield storage
|
|
||||||
logger.debug("Teardown for storage")
|
|
||||||
try:
|
|
||||||
storage.clear_bucket(storage_config.storage_bucket)
|
|
||||||
except:
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="session")
|
|
||||||
def queue_config(payload_processor_type):
|
|
||||||
config = get_config()
|
config = get_config()
|
||||||
# FIXME: It looks like rabbitmq_heartbeat has to be greater than rabbitmq_connection_sleep. If this is expected, the
|
config.rabbitmq_connection_sleep = 2
|
||||||
# user should not be abele to insert non working values.
|
config.rabbitmq_heartbeat = 4
|
||||||
config.rabbitmq_heartbeat = config.rabbitmq_connection_sleep + 1
|
|
||||||
return config
|
return config
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="session")
|
|
||||||
def queue_manager(queue_config):
|
|
||||||
queue_manager = QueueManager(queue_config)
|
|
||||||
return queue_manager
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def request_payload():
|
def payload(x_tenant_id):
|
||||||
|
x_tenant_entry = {"X-TENANT-ID": x_tenant_id} if x_tenant_id else {}
|
||||||
return {
|
return {
|
||||||
"dossierId": "test",
|
"dossierId": "test",
|
||||||
"fileId": "test",
|
"fileId": "test",
|
||||||
"targetFileExtension": "json.gz",
|
"targetFileExtension": "json.gz",
|
||||||
"responseFileExtension": "json.gz",
|
"responseFileExtension": "json.gz",
|
||||||
|
**x_tenant_entry,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ -93,3 +67,17 @@ def response_payload():
|
|||||||
"dossierId": "test",
|
"dossierId": "test",
|
||||||
"fileId": "test",
|
"fileId": "test",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="session")
|
||||||
|
def storage(test_storage_config):
|
||||||
|
logger.debug("Setup for storage")
|
||||||
|
storage = get_storage_from_config(test_storage_config)
|
||||||
|
storage.make_bucket(test_storage_config.storage_bucket)
|
||||||
|
storage.clear_bucket(test_storage_config.storage_bucket)
|
||||||
|
yield storage
|
||||||
|
logger.debug("Teardown for storage")
|
||||||
|
try:
|
||||||
|
storage.clear_bucket(test_storage_config.storage_bucket)
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
|||||||
@ -4,43 +4,41 @@ import time
|
|||||||
import pytest
|
import pytest
|
||||||
import requests
|
import requests
|
||||||
|
|
||||||
from pyinfra.config import get_config
|
from pyinfra.payload_processing.monitor import PrometheusMonitor
|
||||||
from pyinfra.payload_processing.monitor import get_monitor
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="class")
|
@pytest.fixture(scope="class")
|
||||||
def monitor_config():
|
def monitored_mock_function(metric_prefix, host, port):
|
||||||
config = get_config()
|
|
||||||
config.prometheus_metric_prefix = "monitor_test"
|
|
||||||
config.prometheus_port = 8000
|
|
||||||
return config
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="class")
|
|
||||||
def prometheus_monitor(monitor_config):
|
|
||||||
return get_monitor(monitor_config)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def monitored_mock_function(prometheus_monitor):
|
|
||||||
def process(data=None):
|
def process(data=None):
|
||||||
time.sleep(2)
|
time.sleep(2)
|
||||||
return ["result1", "result2", "result3"]
|
return ["result1", "result2", "result3"]
|
||||||
|
|
||||||
return prometheus_monitor(process)
|
monitor = PrometheusMonitor(metric_prefix, host, port)
|
||||||
|
return monitor(process)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def metric_endpoint(host, port):
|
||||||
|
return f"http://{host}:{port}/prometheus"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("metric_prefix, host, port", [("test", "0.0.0.0", 8000)], scope="class")
|
||||||
class TestPrometheusMonitor:
|
class TestPrometheusMonitor:
|
||||||
def test_prometheus_endpoint_is_available(self, prometheus_monitor, monitor_config):
|
def test_prometheus_endpoint_is_available(self, metric_endpoint, monitored_mock_function):
|
||||||
resp = requests.get(f"http://{monitor_config.prometheus_host}:{monitor_config.prometheus_port}/prometheus")
|
resp = requests.get(metric_endpoint)
|
||||||
assert resp.status_code == 200
|
assert resp.status_code == 200
|
||||||
|
|
||||||
def test_processing_with_a_monitored_fn_increases_parameter_counter(self, monitored_mock_function, monitor_config):
|
def test_processing_with_a_monitored_fn_increases_parameter_counter(
|
||||||
|
self,
|
||||||
|
metric_endpoint,
|
||||||
|
metric_prefix,
|
||||||
|
monitored_mock_function,
|
||||||
|
):
|
||||||
monitored_mock_function(data=None)
|
monitored_mock_function(data=None)
|
||||||
resp = requests.get(f"http://{monitor_config.prometheus_host}:{monitor_config.prometheus_port}/prometheus")
|
resp = requests.get(metric_endpoint)
|
||||||
pattern = re.compile(r".*monitor_test_processing_time_count (\d\.\d).*")
|
pattern = re.compile(rf".*{metric_prefix}_processing_time_count (\d\.\d).*")
|
||||||
assert pattern.search(resp.text).group(1) == "1.0"
|
assert pattern.search(resp.text).group(1) == "1.0"
|
||||||
|
|
||||||
monitored_mock_function(data=None)
|
monitored_mock_function(data=None)
|
||||||
resp = requests.get(f"http://{monitor_config.prometheus_host}:{monitor_config.prometheus_port}/prometheus")
|
resp = requests.get(metric_endpoint)
|
||||||
assert pattern.search(resp.text).group(1) == "2.0"
|
assert pattern.search(resp.text).group(1) == "2.0"
|
||||||
|
|||||||
53
tests/payload_parsing_test.py
Normal file
53
tests/payload_parsing_test.py
Normal file
@ -0,0 +1,53 @@
|
|||||||
|
import pytest
|
||||||
|
|
||||||
|
from pyinfra.payload_processing.payload import (
|
||||||
|
QueueMessagePayload,
|
||||||
|
QueueMessagePayloadParser,
|
||||||
|
)
|
||||||
|
from pyinfra.utils.file_extension_parsing import make_file_extension_parser
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def expected_parsed_payload(x_tenant_id):
|
||||||
|
return QueueMessagePayload(
|
||||||
|
dossier_id="test",
|
||||||
|
file_id="test",
|
||||||
|
x_tenant_id=x_tenant_id,
|
||||||
|
target_file_extension="json.gz",
|
||||||
|
response_file_extension="json.gz",
|
||||||
|
target_file_type="json",
|
||||||
|
target_compression_type="gz",
|
||||||
|
response_file_type="json",
|
||||||
|
response_compression_type="gz",
|
||||||
|
target_file_name="test/test.json.gz",
|
||||||
|
response_file_name="test/test.json.gz",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def file_extension_parser(allowed_file_types, allowed_compression_types):
|
||||||
|
return make_file_extension_parser(allowed_file_types, allowed_compression_types)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def payload_parser(file_extension_parser):
|
||||||
|
return QueueMessagePayloadParser(file_extension_parser)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("allowed_file_types,allowed_compression_types", [(["json", "pdf"], ["gz"])])
|
||||||
|
class TestPayload:
|
||||||
|
@pytest.mark.parametrize("x_tenant_id", [None, "klaus"])
|
||||||
|
def test_payload_is_parsed_correctly(self, payload_parser, payload, expected_parsed_payload):
|
||||||
|
payload = payload_parser(payload)
|
||||||
|
assert payload == expected_parsed_payload
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"extension,expected",
|
||||||
|
[
|
||||||
|
("json.gz", ("json", "gz")),
|
||||||
|
("json", ("json", None)),
|
||||||
|
("prefix.json.gz", ("json", "gz")),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_parse_file_extension(self, file_extension_parser, extension, expected):
|
||||||
|
assert file_extension_parser(extension) == expected
|
||||||
@ -8,14 +8,6 @@ import requests
|
|||||||
from pyinfra.payload_processing.processor import make_payload_processor
|
from pyinfra.payload_processing.processor import make_payload_processor
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="session")
|
|
||||||
def file_processor_mock():
|
|
||||||
def inner(json_file: dict):
|
|
||||||
return [json_file]
|
|
||||||
|
|
||||||
return inner
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def target_file():
|
def target_file():
|
||||||
contents = {"numberOfPages": 10, "content1": "value1", "content2": "value2"}
|
contents = {"numberOfPages": 10, "content1": "value1", "content2": "value2"}
|
||||||
@ -23,54 +15,60 @@ def target_file():
|
|||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def file_names(request_payload):
|
def file_names(payload):
|
||||||
dossier_id, file_id, target_suffix, response_suffix = itemgetter(
|
dossier_id, file_id, target_suffix, response_suffix = itemgetter(
|
||||||
"dossierId",
|
"dossierId",
|
||||||
"fileId",
|
"fileId",
|
||||||
"targetFileExtension",
|
"targetFileExtension",
|
||||||
"responseFileExtension",
|
"responseFileExtension",
|
||||||
)(request_payload)
|
)(payload)
|
||||||
return f"{dossier_id}/{file_id}.{target_suffix}", f"{dossier_id}/{file_id}.{response_suffix}"
|
return f"{dossier_id}/{file_id}.{target_suffix}", f"{dossier_id}/{file_id}.{response_suffix}"
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="session")
|
@pytest.fixture(scope="session")
|
||||||
def payload_processor(file_processor_mock, processing_config):
|
def payload_processor(test_storage_config):
|
||||||
yield make_payload_processor(file_processor_mock, processing_config)
|
def file_processor_mock(json_file: dict):
|
||||||
|
return [json_file]
|
||||||
|
|
||||||
|
yield make_payload_processor(file_processor_mock, test_storage_config)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("client_name", ["s3"], scope="session")
|
@pytest.mark.parametrize("storage_backend", ["s3"], scope="session")
|
||||||
|
@pytest.mark.parametrize("bucket_name", ["testbucket"], scope="session")
|
||||||
@pytest.mark.parametrize("monitoring_enabled", [True, False], scope="session")
|
@pytest.mark.parametrize("monitoring_enabled", [True, False], scope="session")
|
||||||
|
@pytest.mark.parametrize("x_tenant_id", [None])
|
||||||
class TestPayloadProcessor:
|
class TestPayloadProcessor:
|
||||||
def test_payload_processor_yields_correct_response_and_uploads_result(
|
def test_payload_processor_yields_correct_response_and_uploads_result(
|
||||||
self,
|
self,
|
||||||
payload_processor,
|
payload_processor,
|
||||||
storage,
|
storage,
|
||||||
bucket_name,
|
bucket_name,
|
||||||
request_payload,
|
payload,
|
||||||
response_payload,
|
response_payload,
|
||||||
target_file,
|
target_file,
|
||||||
file_names,
|
file_names,
|
||||||
):
|
):
|
||||||
storage.clear_bucket(bucket_name)
|
storage.clear_bucket(bucket_name)
|
||||||
storage.put_object(bucket_name, file_names[0], target_file)
|
storage.put_object(bucket_name, file_names[0], target_file)
|
||||||
response = payload_processor(request_payload)
|
response = payload_processor(payload)
|
||||||
|
|
||||||
assert response == response_payload
|
assert response == response_payload
|
||||||
|
|
||||||
data_received = storage.get_object(bucket_name, file_names[1])
|
data_received = storage.get_object(bucket_name, file_names[1])
|
||||||
|
|
||||||
assert json.loads((gzip.decompress(data_received)).decode("utf-8")) == {
|
assert json.loads((gzip.decompress(data_received)).decode("utf-8")) == {
|
||||||
**request_payload,
|
**payload,
|
||||||
"data": [json.loads(gzip.decompress(target_file).decode("utf-8"))],
|
"data": [json.loads(gzip.decompress(target_file).decode("utf-8"))],
|
||||||
}
|
}
|
||||||
|
|
||||||
def test_catching_of_processing_failure(self, payload_processor, storage, bucket_name, request_payload):
|
def test_catching_of_processing_failure(self, payload_processor, storage, bucket_name, payload):
|
||||||
storage.clear_bucket(bucket_name)
|
storage.clear_bucket(bucket_name)
|
||||||
with pytest.raises(Exception):
|
with pytest.raises(Exception):
|
||||||
payload_processor(request_payload)
|
payload_processor(payload)
|
||||||
|
|
||||||
def test_prometheus_endpoint_is_available(self, processing_config):
|
def test_prometheus_endpoint_is_available(self, test_storage_config, monitoring_enabled, storage_backend, x_tenant_id):
|
||||||
|
if monitoring_enabled:
|
||||||
resp = requests.get(
|
resp = requests.get(
|
||||||
f"http://{processing_config.prometheus_host}:{processing_config.prometheus_port}/prometheus"
|
f"http://{test_storage_config.prometheus_host}:{test_storage_config.prometheus_port}/prometheus"
|
||||||
)
|
)
|
||||||
assert resp.status_code == 200
|
assert resp.status_code == 200
|
||||||
@ -1,44 +0,0 @@
|
|||||||
import pytest
|
|
||||||
|
|
||||||
from pyinfra.config import get_config
|
|
||||||
from pyinfra.payload_processing.payload import (
|
|
||||||
QueueMessagePayload,
|
|
||||||
get_queue_message_payload_parser,
|
|
||||||
)
|
|
||||||
from pyinfra.utils.file_extension_parsing import make_file_extension_parser
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="session")
|
|
||||||
def payload_config():
|
|
||||||
return get_config()
|
|
||||||
|
|
||||||
|
|
||||||
class TestPayload:
|
|
||||||
def test_payload_is_parsed_correctly(self, request_payload, payload_config):
|
|
||||||
parse_payload = get_queue_message_payload_parser(payload_config)
|
|
||||||
payload = parse_payload(request_payload)
|
|
||||||
assert payload == QueueMessagePayload(
|
|
||||||
dossier_id="test",
|
|
||||||
file_id="test",
|
|
||||||
target_file_extension="json.gz",
|
|
||||||
response_file_extension="json.gz",
|
|
||||||
target_file_type="json",
|
|
||||||
target_compression_type="gz",
|
|
||||||
response_file_type="json",
|
|
||||||
response_compression_type="gz",
|
|
||||||
target_file_name="test/test.json.gz",
|
|
||||||
response_file_name="test/test.json.gz",
|
|
||||||
)
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
|
||||||
"extension,expected",
|
|
||||||
[
|
|
||||||
("json.gz", ("json", "gz")),
|
|
||||||
("json", ("json", None)),
|
|
||||||
("prefix.json.gz", ("json", "gz")),
|
|
||||||
],
|
|
||||||
)
|
|
||||||
@pytest.mark.parametrize("allowed_file_types,allowed_compression_types", [(["json", "pdf"], ["gz"])])
|
|
||||||
def test_parse_file_extension(self, extension, expected, allowed_file_types, allowed_compression_types):
|
|
||||||
parse = make_file_extension_parser(allowed_file_types, allowed_compression_types)
|
|
||||||
assert parse(extension) == expected
|
|
||||||
@ -8,15 +8,16 @@ import pika.exceptions
|
|||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from pyinfra.queue.development_queue_manager import DevelopmentQueueManager
|
from pyinfra.queue.development_queue_manager import DevelopmentQueueManager
|
||||||
|
from pyinfra.queue.queue_manager import QueueManager
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
logger.setLevel(logging.DEBUG)
|
logger.setLevel(logging.DEBUG)
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="session")
|
@pytest.fixture(scope="session")
|
||||||
def development_queue_manager(queue_config):
|
def development_queue_manager(test_queue_config):
|
||||||
queue_config.rabbitmq_heartbeat = 7200
|
test_queue_config.rabbitmq_heartbeat = 7200
|
||||||
development_queue_manager = DevelopmentQueueManager(queue_config)
|
development_queue_manager = DevelopmentQueueManager(test_queue_config)
|
||||||
yield development_queue_manager
|
yield development_queue_manager
|
||||||
logger.info("Tearing down development queue manager...")
|
logger.info("Tearing down development queue manager...")
|
||||||
try:
|
try:
|
||||||
@ -26,10 +27,10 @@ def development_queue_manager(queue_config):
|
|||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="session")
|
@pytest.fixture(scope="session")
|
||||||
def payload_processing_time(queue_config, offset=5):
|
def payload_processing_time(test_queue_config, offset=5):
|
||||||
# FIXME: this implicitly tests the heartbeat when running the end-to-end test. There should be another way to test
|
# FIXME: this implicitly tests the heartbeat when running the end-to-end test. There should be another way to test
|
||||||
# this explicitly.
|
# this explicitly.
|
||||||
return queue_config.rabbitmq_heartbeat + offset
|
return test_queue_config.rabbitmq_heartbeat + offset
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="session")
|
@pytest.fixture(scope="session")
|
||||||
@ -48,10 +49,11 @@ def payload_processor(response_payload, payload_processing_time, payload_process
|
|||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="session", autouse=True)
|
@pytest.fixture(scope="session", autouse=True)
|
||||||
def start_queue_consumer(queue_manager, payload_processor, sleep_seconds=5):
|
def start_queue_consumer(test_queue_config, payload_processor, sleep_seconds=5):
|
||||||
def consume_queue():
|
def consume_queue():
|
||||||
queue_manager.start_consuming(payload_processor)
|
queue_manager.start_consuming(payload_processor)
|
||||||
|
|
||||||
|
queue_manager = QueueManager(test_queue_config)
|
||||||
p = Process(target=consume_queue)
|
p = Process(target=consume_queue)
|
||||||
p.start()
|
p.start()
|
||||||
logger.info(f"Setting up consumer, waiting for {sleep_seconds}...")
|
logger.info(f"Setting up consumer, waiting for {sleep_seconds}...")
|
||||||
@ -65,39 +67,40 @@ def start_queue_consumer(queue_manager, payload_processor, sleep_seconds=5):
|
|||||||
def message_properties(message_headers):
|
def message_properties(message_headers):
|
||||||
if not message_headers:
|
if not message_headers:
|
||||||
return pika.BasicProperties(headers=None)
|
return pika.BasicProperties(headers=None)
|
||||||
elif message_headers == "x-tenant-id":
|
elif message_headers == "X-TENANT-ID":
|
||||||
return pika.BasicProperties(headers={"x-tenant-id": "redaction"})
|
return pika.BasicProperties(headers={"X-TENANT-ID": "redaction"})
|
||||||
else:
|
else:
|
||||||
raise Exception(f"Invalid {message_headers=}.")
|
raise Exception(f"Invalid {message_headers=}.")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("x_tenant_id", [None])
|
||||||
class TestQueueManager:
|
class TestQueueManager:
|
||||||
# FIXME: All tests here are wonky. This is due to the implementation of running the process-blocking queue_manager
|
# FIXME: All tests here are wonky. This is due to the implementation of running the process-blocking queue_manager
|
||||||
# in a subprocess. It is then very hard to interact directly with the subprocess. If you have a better idea, please
|
# in a subprocess. It is then very hard to interact directly with the subprocess. If you have a better idea, please
|
||||||
# refactor; the tests here are insufficient to ensure the functionality of the queue manager!
|
# refactor; the tests here are insufficient to ensure the functionality of the queue manager!
|
||||||
@pytest.mark.parametrize("payload_processor_type", ["mock"], scope="session")
|
@pytest.mark.parametrize("payload_processor_type", ["mock"], scope="session")
|
||||||
def test_message_processing_does_not_block_heartbeat(
|
def test_message_processing_does_not_block_heartbeat(
|
||||||
self, development_queue_manager, request_payload, response_payload, payload_processing_time
|
self, development_queue_manager, payload, response_payload, payload_processing_time
|
||||||
):
|
):
|
||||||
development_queue_manager.clear_queues()
|
development_queue_manager.clear_queues()
|
||||||
development_queue_manager.publish_request(request_payload)
|
development_queue_manager.publish_request(payload)
|
||||||
time.sleep(payload_processing_time + 10)
|
time.sleep(payload_processing_time + 10)
|
||||||
_, _, body = development_queue_manager.get_response()
|
_, _, body = development_queue_manager.get_response()
|
||||||
result = json.loads(body)
|
result = json.loads(body)
|
||||||
assert result == response_payload
|
assert result == response_payload
|
||||||
|
|
||||||
@pytest.mark.parametrize("message_headers", [None, "x-tenant-id"])
|
@pytest.mark.parametrize("message_headers", [None, "X-TENANT-ID"])
|
||||||
@pytest.mark.parametrize("payload_processor_type", ["mock"], scope="session")
|
@pytest.mark.parametrize("payload_processor_type", ["mock"], scope="session")
|
||||||
def test_queue_manager_forwards_message_headers(
|
def test_queue_manager_forwards_message_headers(
|
||||||
self,
|
self,
|
||||||
development_queue_manager,
|
development_queue_manager,
|
||||||
request_payload,
|
payload,
|
||||||
response_payload,
|
response_payload,
|
||||||
payload_processing_time,
|
payload_processing_time,
|
||||||
message_properties,
|
message_properties,
|
||||||
):
|
):
|
||||||
development_queue_manager.clear_queues()
|
development_queue_manager.clear_queues()
|
||||||
development_queue_manager.publish_request(request_payload, message_properties)
|
development_queue_manager.publish_request(payload, message_properties)
|
||||||
time.sleep(payload_processing_time + 10)
|
time.sleep(payload_processing_time + 10)
|
||||||
_, properties, _ = development_queue_manager.get_response()
|
_, properties, _ = development_queue_manager.get_response()
|
||||||
assert properties.headers == message_properties.headers
|
assert properties.headers == message_properties.headers
|
||||||
@ -109,12 +112,12 @@ class TestQueueManager:
|
|||||||
def test_failed_message_processing_is_handled(
|
def test_failed_message_processing_is_handled(
|
||||||
self,
|
self,
|
||||||
development_queue_manager,
|
development_queue_manager,
|
||||||
request_payload,
|
payload,
|
||||||
response_payload,
|
response_payload,
|
||||||
payload_processing_time,
|
payload_processing_time,
|
||||||
):
|
):
|
||||||
development_queue_manager.clear_queues()
|
development_queue_manager.clear_queues()
|
||||||
development_queue_manager.publish_request(request_payload)
|
development_queue_manager.publish_request(payload)
|
||||||
time.sleep(payload_processing_time + 10)
|
time.sleep(payload_processing_time + 10)
|
||||||
_, _, body = development_queue_manager.get_response()
|
_, _, body = development_queue_manager.get_response()
|
||||||
assert not body
|
assert not body
|
||||||
@ -6,7 +6,9 @@ logger = logging.getLogger(__name__)
|
|||||||
logger.setLevel(logging.DEBUG)
|
logger.setLevel(logging.DEBUG)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("client_name", ["azure", "s3"], scope="session")
|
@pytest.mark.parametrize("storage_backend", ["azure", "s3"], scope="session")
|
||||||
|
@pytest.mark.parametrize("bucket_name", ["testbucket"], scope="session")
|
||||||
|
@pytest.mark.parametrize("monitoring_enabled", [False], scope="session")
|
||||||
class TestStorage:
|
class TestStorage:
|
||||||
def test_clearing_bucket_yields_empty_bucket(self, storage, bucket_name):
|
def test_clearing_bucket_yields_empty_bucket(self, storage, bucket_name):
|
||||||
storage.clear_bucket(bucket_name)
|
storage.clear_bucket(bucket_name)
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user