Merge branch 'opentel' into 'master'

RES-506, RES-507, RES-499, RES-434, RES-398

See merge request knecon/research/pyinfra!82
This commit is contained in:
Julius Unverfehrt 2024-01-31 11:21:17 +01:00
commit dc413cea82
53 changed files with 2900 additions and 2280 deletions


@ -8,4 +8,4 @@ default:
run-tests:
script:
- pytest .
- echo "Disabled until we have an automated way to run docker compose before tests."


@ -2,10 +2,10 @@
# See https://pre-commit.com/hooks.html for more hooks
exclude: ^(docs/|notebooks/|data/|src/secrets/|src/static/|src/templates/|tests)
default_language_version:
python: python3.8
python: python3.10
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.4.0
rev: v4.5.0
hooks:
- id: trailing-whitespace
- id: end-of-file-fixer
@ -26,7 +26,7 @@ repos:
args: ["--profile", "black"]
- repo: https://github.com/psf/black
rev: 23.1.0
rev: 23.12.1
hooks:
- id: black
# exclude: ^(docs/|notebooks/|data/|src/secrets/)

README.md

@ -2,75 +2,126 @@
1. [ About ](#about)
2. [ Configuration ](#configuration)
3. [ Response Format ](#response-format)
4. [ Usage & API ](#usage--api)
3. [ Queue Manager ](#queue-manager)
4. [ Module Installation ](#module-installation)
5. [ Scripts ](#scripts)
6. [ Tests ](#tests)
## About
Common Module with the infrastructure to deploy Research Projects.
The Infrastructure expects to be deployed in the same Pod / local environment as the analysis container and handles all outbound communication.
Shared library for the research team, containing code related to infrastructure and communication with other services.
Offers a simple interface for processing data and sending responses via AMQP, monitoring via Prometheus, and storage
access via S3 or Azure. It also exports traces via OpenTelemetry for queue messages and webserver requests.
To start, see the [complete example](pyinfra/examples.py) which shows how to use all features of the service and can be
imported and used directly for default research service pipelines (data ID in message, download data from storage,
upload result while offering Prometheus monitoring, /health and /ready endpoints and multi-tenancy support).
## Configuration
A configuration is located in `/config.yaml`. All relevant variables can be configured via exporting environment variables.
Configuration is done via `Dynaconf`. This means that you can use environment variables, a `.env` file or `.toml`
file(s) to configure the service. You can also combine these methods. The precedence is
`environment variables > .env > .toml`. It is recommended to load settings with the provided
[`load_settings`](pyinfra/config/loader.py) function, which you can combine with the provided
[`parse_args`](pyinfra/config/loader.py) function. This allows you to load settings from a `.toml` file or a folder with
`.toml` files and override them with environment variables.
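As a plain-Python sketch of that precedence rule (Dynaconf performs the real merging; `merge_settings` here is purely illustrative):

```python
# Illustrative only: Dynaconf does the real merging. This sketch just shows
# the order in which sources win: environment variables > .env > .toml.
def merge_settings(toml_values: dict, dotenv_values: dict, env_values: dict) -> dict:
    merged: dict = {}
    # Apply sources from lowest to highest precedence; later updates win.
    for source in (toml_values, dotenv_values, env_values):
        merged.update(source)
    return merged

settings = merge_settings(
    toml_values={"logging": "INFO", "webserver_port": 8080},
    dotenv_values={"logging": "WARNING"},
    env_values={"logging": "DEBUG"},
)
# The environment variable value wins: settings["logging"] == "DEBUG"
```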
| Environment Variable | Default | Description |
|-------------------------------|----------------------------------|--------------------------------------------------------------------------|
| LOGGING_LEVEL_ROOT | "DEBUG" | Logging level for service logger |
| MONITORING_ENABLED | True | Enables Prometheus monitoring |
| PROMETHEUS_METRIC_PREFIX | "redactmanager_research_service" | Prometheus metric prefix, per convention '{product_name}_{service name}' |
| PROMETHEUS_HOST | "127.0.0.1" | Prometheus webserver address |
| PROMETHEUS_PORT | 8080 | Prometheus webserver port |
| RABBITMQ_HOST | "localhost" | RabbitMQ host address |
| RABBITMQ_PORT | "5672" | RabbitMQ host port |
| RABBITMQ_USERNAME | "user" | RabbitMQ username |
| RABBITMQ_PASSWORD | "bitnami" | RabbitMQ password |
| RABBITMQ_HEARTBEAT | 60 | Controls AMQP heartbeat timeout in seconds |
| RABBITMQ_CONNECTION_SLEEP | 5 | Controls AMQP connection sleep timer in seconds |
| REQUEST_QUEUE | "request_queue" | Requests to service |
| RESPONSE_QUEUE | "response_queue" | Responses by service |
| DEAD_LETTER_QUEUE | "dead_letter_queue" | Messages that failed to process |
| STORAGE_BACKEND | "s3" | The type of storage to use {s3, azure} |
| STORAGE_BUCKET | "redaction" | The bucket / container to pull files specified in queue requests from |
| STORAGE_ENDPOINT | "http://127.0.0.1:9000" | Endpoint for s3 storage |
| STORAGE_KEY | "root" | User for s3 storage |
| STORAGE_SECRET | "password" | Password for s3 storage |
| STORAGE_AZURECONNECTIONSTRING | "DefaultEndpointsProtocol=..." | Connection string for Azure storage |
| STORAGE_AZURECONTAINERNAME | "redaction" | AKS container |
| WRITE_CONSUMER_TOKEN | "False" | Value to see if we should write a consumer token to a file |
The following table lists all available settings. You can find a preconfigured settings file for this service in
Bitbucket. These are the complete settings; you only need all of them when using every feature of the service as
described in the [complete example](pyinfra/examples.py).
## Response Format
| Environment Variable | Internal / .toml Name | Description |
|--------------------------------------|------------------------------------|-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
| LOGGING__LEVEL | logging.level | Log level |
| METRICS__PROMETHEUS__ENABLED | metrics.prometheus.enabled | Enable Prometheus metrics collection |
| METRICS__PROMETHEUS__PREFIX | metrics.prometheus.prefix | Prefix for Prometheus metrics (e.g. {product}-{service}) |
| WEBSERVER__HOST | webserver.host | Host of the webserver (offering e.g. /prometheus, /ready and /health endpoints) |
| WEBSERVER__PORT | webserver.port | Port of the webserver |
| RABBITMQ__HOST | rabbitmq.host | Host of the RabbitMQ server |
| RABBITMQ__PORT | rabbitmq.port | Port of the RabbitMQ server |
| RABBITMQ__USERNAME | rabbitmq.username | Username for the RabbitMQ server |
| RABBITMQ__PASSWORD | rabbitmq.password | Password for the RabbitMQ server |
| RABBITMQ__HEARTBEAT | rabbitmq.heartbeat | Heartbeat for the RabbitMQ server |
| RABBITMQ__CONNECTION_SLEEP           | rabbitmq.connection_sleep          | Sleep interval during message processing. Must be a divisor of the heartbeat and should not be too large, since queue interactions (such as receiving new messages) only happen at these intervals. This is also the minimum time the service needs to process a message. |
| RABBITMQ__INPUT_QUEUE | rabbitmq.input_queue | Name of the input queue |
| RABBITMQ__OUTPUT_QUEUE | rabbitmq.output_queue | Name of the output queue |
| RABBITMQ__DEAD_LETTER_QUEUE | rabbitmq.dead_letter_queue | Name of the dead letter queue |
| STORAGE__BACKEND | storage.backend | Storage backend to use (currently only "s3" and "azure" are supported) |
| STORAGE__S3__BUCKET | storage.s3.bucket | Name of the S3 bucket |
| STORAGE__S3__ENDPOINT | storage.s3.endpoint | Endpoint of the S3 server |
| STORAGE__S3__KEY | storage.s3.key | Access key for the S3 server |
| STORAGE__S3__SECRET | storage.s3.secret | Secret key for the S3 server |
| STORAGE__S3__REGION | storage.s3.region | Region of the S3 server |
| STORAGE__AZURE__CONTAINER | storage.azure.container_name | Name of the Azure container |
| STORAGE__AZURE__CONNECTION_STRING | storage.azure.connection_string | Connection string for the Azure server |
| STORAGE__TENANT_SERVER__PUBLIC_KEY | storage.tenant_server.public_key | Public key of the tenant server |
| STORAGE__TENANT_SERVER__ENDPOINT | storage.tenant_server.endpoint | Endpoint of the tenant server |
| TRACING__OPENTELEMETRY__ENDPOINT     | tracing.opentelemetry.endpoint     | Endpoint to which OpenTelemetry traces are exported |
| TRACING__OPENTELEMETRY__SERVICE_NAME | tracing.opentelemetry.service_name | Name of the service as displayed in the traces collected |
### OpenTelemetry
OpenTelemetry (via its Python SDK) is set up to be as unobtrusive as possible: for typical use cases it can be
configured from environment variables, without additional work in the microservice app, although additional
configuration is possible.
`TRACING__OPENTELEMETRY__ENDPOINT` should typically be set
to `http://otel-collector-opentelemetry-collector.otel-collector:4318/v1/traces`.
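Dynaconf maps double underscores in environment variable names to nested settings keys; a minimal sketch of that naming convention (the actual mapping is done by Dynaconf, `env_var_to_settings_key` is illustrative):

```python
def env_var_to_settings_key(name: str) -> str:
    # Double underscores separate nesting levels, e.g.
    # TRACING__OPENTELEMETRY__ENDPOINT -> tracing.opentelemetry.endpoint
    return ".".join(part.lower() for part in name.split("__"))

key = env_var_to_settings_key("TRACING__OPENTELEMETRY__ENDPOINT")
# key == "tracing.opentelemetry.endpoint"
```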
## Queue Manager
The queue manager is responsible for consuming messages from the input queue, processing them and sending the response
to the output queue. The default callback also downloads data from the storage and uploads the result to the storage.
The response message does not contain the data itself, but the identifiers from the input message (including headers
beginning with "X-").
### Standalone Usage
```python
from pyinfra.queue.manager import QueueManager
from pyinfra.queue.callback import make_download_process_upload_callback, DataProcessor
from pyinfra.config.loader import load_settings
settings = load_settings("path/to/settings")
processing_function: DataProcessor # function should expect a dict (json) or bytes (pdf) as input and should return a json serializable object.
queue_manager = QueueManager(settings)
callback = make_download_process_upload_callback(processing_function, settings)
queue_manager.start_consuming(callback)
```
### Usage in a Service
This is the recommended way to use the module. This includes the webserver, Prometheus metrics and health endpoints.
Custom endpoints can be added by adding a new route to the `app` object beforehand. Settings are loaded from files
specified as CLI arguments (e.g. `--settings-path path/to/settings.toml`). The values can also be set or overridden via
environment variables (e.g. `LOGGING__LEVEL=DEBUG`).
The callback can be replaced with a custom one, for example if the data to process is contained in the message itself
and not on the storage.
```python
from pyinfra.config.loader import load_settings, parse_settings_path
from pyinfra.examples import start_standard_queue_consumer
from pyinfra.queue.callback import make_download_process_upload_callback, DataProcessor
processing_function: DataProcessor
arguments = parse_settings_path()
settings = load_settings(arguments.settings_path)
callback = make_download_process_upload_callback(processing_function, settings)
start_standard_queue_consumer(callback, settings) # optionally also pass a FastAPI app object with preconfigured routes
```
### AMQP input message:
### Expected AMQP input message:
Either use the legacy format with dossierId and fileId as strings or the new format where absolute paths are used.
A tenant ID can optionally be provided in the message header (key: "X-TENANT-ID").
```json
{
"targetFilePath": "",
"responseFilePath": ""
}
```
or
```json
{
"dossierId": "",
"fileId": "",
"targetFileExtension": "",
"responseFileExtension": ""
}
```
Optionally, the input message can contain a field with the key `"operations"`.
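For the legacy format, the absolute paths are derived from the IDs and extensions (mirroring what `parse_legacy_queue_message_payload` does internally; `legacy_paths` is an illustrative name, not part of the pyinfra API):

```python
def legacy_paths(payload: dict) -> tuple:
    # Illustrative helper: builds the target and response file paths from a
    # legacy payload, the same way the payload parser constructs them.
    dossier_id, file_id = payload["dossierId"], payload["fileId"]
    target = f"{dossier_id}/{file_id}.{payload['targetFileExtension']}"
    response = f"{dossier_id}/{file_id}.{payload['responseFileExtension']}"
    return target, response

target, response = legacy_paths(
    {"dossierId": "d1", "fileId": "f1", "targetFileExtension": "pdf", "responseFileExtension": "json"}
)
# target == "d1/f1.pdf", response == "d1/f1.json"
```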
### AMQP output message:
All headers beginning with "X-" are forwarded to the message processor, and returned in the response message (e.g.
"X-TENANT-ID" is used to acquire storage information for the tenant).
```json
{
@ -84,19 +135,21 @@ or
```json
{
"dossierId": "",
"fileId": ""
"fileId": "",
"targetFileExtension": "",
"responseFileExtension": ""
}
```
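The forwarding rule for "X-" headers can be sketched as follows (`forward_headers` is an illustrative name; the real logic lives in the queue callback):

```python
def forward_headers(headers: dict) -> dict:
    # Only headers whose keys start with "X-" are carried over into the response.
    return {key: value for key, value in headers.items() if key.startswith("X-")}

forwarded = forward_headers({"X-TENANT-ID": "tenant-a", "content-type": "application/json"})
# forwarded == {"X-TENANT-ID": "tenant-a"}
```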
## Usage & API
## Module Installation
### Setup
Add the respective version of the pyinfra package to your pyproject.toml file. Make sure to add our gitlab registry as a source.
For now, all internal packages used by pyinfra also have to be added to the pyproject.toml file.
Add the respective version of the pyinfra package to your pyproject.toml file. Make sure to add our gitlab registry as a
source.
For now, all internal packages used by pyinfra also have to be added to the pyproject.toml file (namely kn-utils).
Execute `poetry lock` and `poetry install` to install the packages.
You can look up the latest version of the package in the [gitlab registry](https://gitlab.knecon.com/knecon/research/pyinfra/-/packages).
You can look up the latest version of the package in
the [gitlab registry](https://gitlab.knecon.com/knecon/research/pyinfra/-/packages).
For the used versions of internal dependencies, please refer to the [pyproject.toml](pyproject.toml) file.
```toml
@ -110,45 +163,29 @@ url = "https://gitlab.knecon.com/api/v4/groups/19/-/packages/pypi/simple"
priority = "explicit"
```
### API
```python
from pyinfra import config
from pyinfra.payload_processing.processor import make_payload_processor
from pyinfra.queue.queue_manager import QueueManager
pyinfra_config = config.get_config()
process_payload = make_payload_processor(process_data, config=pyinfra_config)
queue_manager = QueueManager(pyinfra_config)
queue_manager.start_consuming(process_payload)
```
`process_data` should expect a dict (json) or bytes (pdf) as input and should return a list of results.
## Scripts
### Run pyinfra locally
**Shell 1**: Start minio and rabbitmq containers
```bash
$ cd tests && docker-compose up
$ cd tests && docker compose up
```
**Shell 2**: Start pyinfra with callback mock
```bash
$ python scripts/start_pyinfra.py
```
**Shell 3**: Upload dummy content on storage and publish message
```bash
$ python scripts/send_request.py
```
## Tests
Running all tests takes a bit longer than you are probably used to, because among other things the required startup
times are quite high for docker-compose-dependent tests. This is why the tests are split into two parts: the first part
contains all tests that do not require docker-compose, and the second part contains all tests that do.
By default, only the first part is executed, but when releasing a new version, all tests should be executed.
Tests require a running minio and rabbitmq container, meaning you have to run `docker compose up` in the tests folder
before running the tests.

poetry.lock generated

File diff suppressed because it is too large


@ -1,106 +0,0 @@
from os import environ
from typing import Union
from pyinfra.utils.url_parsing import validate_and_parse_s3_endpoint
def read_from_environment(environment_variable_name, default_value):
return environ.get(environment_variable_name, default_value)
def normalize_bool(value: Union[str, bool]):
return value if isinstance(value, bool) else value in ["True", "true"]
class Config:
def __init__(self):
# Logging level for service logger
self.logging_level_root = read_from_environment("LOGGING_LEVEL_ROOT", "DEBUG")
# Enables Prometheus monitoring
self.monitoring_enabled = normalize_bool(read_from_environment("MONITORING_ENABLED", True))
# Prometheus metric prefix, per convention '{product_name}_{service_name}_{parameter}'
# In the current implementation, the results of a service define the parameter that is monitored,
# i.e. analysis result per image means processing time per image is monitored.
# TODO: add validator since some characters like '-' are not allowed by python prometheus
self.prometheus_metric_prefix = read_from_environment(
"PROMETHEUS_METRIC_PREFIX", "redactmanager_research_service_parameter"
)
# Prometheus webserver address and port
self.prometheus_host = "0.0.0.0"
self.prometheus_port = 8080
# RabbitMQ host address
self.rabbitmq_host = read_from_environment("RABBITMQ_HOST", "localhost")
# RabbitMQ host port
self.rabbitmq_port = read_from_environment("RABBITMQ_PORT", "5672")
# RabbitMQ username
self.rabbitmq_username = read_from_environment("RABBITMQ_USERNAME", "user")
# RabbitMQ password
self.rabbitmq_password = read_from_environment("RABBITMQ_PASSWORD", "bitnami")
# Controls AMQP heartbeat timeout in seconds
self.rabbitmq_heartbeat = int(read_from_environment("RABBITMQ_HEARTBEAT", 60))
# Controls AMQP connection sleep timer in seconds
# important for heartbeat to come through while main function runs on other thread
self.rabbitmq_connection_sleep = int(read_from_environment("RABBITMQ_CONNECTION_SLEEP", 5))
# Queue name for requests to the service
self.request_queue = read_from_environment("REQUEST_QUEUE", "request_queue")
# Queue name for responses by service
self.response_queue = read_from_environment("RESPONSE_QUEUE", "response_queue")
# Queue name for failed messages
self.dead_letter_queue = read_from_environment("DEAD_LETTER_QUEUE", "dead_letter_queue")
# The type of storage to use {s3, azure}
self.storage_backend = read_from_environment("STORAGE_BACKEND", "s3")
# The bucket / container to pull files specified in queue requests from
if self.storage_backend == "s3":
self.storage_bucket = read_from_environment("STORAGE_BUCKET_NAME", "redaction")
else:
self.storage_bucket = read_from_environment("STORAGE_AZURECONTAINERNAME", "redaction")
# S3 connection security flag and endpoint
storage_address = read_from_environment("STORAGE_ENDPOINT", "http://127.0.0.1:9000")
self.storage_secure_connection, self.storage_endpoint = validate_and_parse_s3_endpoint(storage_address)
# User for s3 storage
self.storage_key = read_from_environment("STORAGE_KEY", "root")
# Password for s3 storage
self.storage_secret = read_from_environment("STORAGE_SECRET", "password")
# Region for s3 storage
self.storage_region = read_from_environment("STORAGE_REGION", "eu-central-1")
# Connection string for Azure storage
self.storage_azureconnectionstring = read_from_environment(
"STORAGE_AZURECONNECTIONSTRING",
"DefaultEndpointsProtocol=...",
)
# Allowed file types for downloaded and uploaded storage objects that get processed by the service
self.allowed_file_types = ["json", "pdf"]
self.allowed_compression_types = ["gz"]
self.allowed_processing_parameters = ["operation"]
# config for x-tenant-endpoint to receive storage connection information per tenant
self.tenant_decryption_public_key = read_from_environment("TENANT_PUBLIC_KEY", "redaction")
self.tenant_endpoint = read_from_environment("TENANT_ENDPOINT", "http://tenant-user-management:8081/internal-api/tenants")
# Value to see if we should write a consumer token to a file
self.write_consumer_token = read_from_environment("WRITE_CONSUMER_TOKEN", "False")
def get_config() -> Config:
return Config()

pyinfra/config/loader.py Normal file

@ -0,0 +1,133 @@
import argparse
import os
from functools import partial
from pathlib import Path
from typing import Union
from dynaconf import Dynaconf, ValidationError, Validator
from funcy import lflatten
from kn_utils.logging import logger
# This path is meant for testing purposes and convenience. It probably won't reflect the actual root path when pyinfra is
# installed as a package, so don't use it in production code; define your own root path as described in `load_settings`.
local_pyinfra_root_path = Path(__file__).parents[2]
def load_settings(
settings_path: Union[str, Path, list] = "config/",
root_path: Union[str, Path] = None,
validators: list[Validator] = None,
):
"""Load settings from .toml files, .env and environment variables. Also ensures a ROOT_PATH environment variable is
set. If ROOT_PATH is not set and no root_path argument is passed, the current working directory is used as root.
Settings paths can be a single .toml file, a folder containing .toml files or a list of .toml files and folders.
If a ROOT_PATH environment variable is set, it is not overwritten by the root_path argument.
If a folder is passed, all .toml files in the folder are loaded. If settings path is None, only .env and
environment variables are loaded. If settings_path are relative paths, they are joined with the root_path argument.
"""
root_path = get_or_set_root_path(root_path)
validators = validators or get_pyinfra_validators()
settings_files = normalize_to_settings_files(settings_path, root_path)
settings = Dynaconf(
load_dotenv=True,
envvar_prefix=False,
settings_files=settings_files,
)
validate_settings(settings, validators)
logger.info("Settings loaded and validated.")
return settings
def normalize_to_settings_files(settings_path: Union[str, Path, list], root_path: Union[str, Path]):
if settings_path is None:
logger.info("No settings path specified, only loading .env and environment variables.")
settings_files = []
elif isinstance(settings_path, str) or isinstance(settings_path, Path):
settings_files = [settings_path]
elif isinstance(settings_path, list):
settings_files = settings_path
else:
raise ValueError(f"Invalid settings path: {settings_path=}")
settings_files = lflatten(map(partial(_normalize_and_verify, root_path=root_path), settings_files))
logger.debug(f"Normalized settings files: {settings_files}")
return settings_files
def _normalize_and_verify(settings_path: Path, root_path: Path):
settings_path = Path(settings_path)
root_path = Path(root_path)
if not settings_path.is_absolute():
logger.debug(f"Settings path is not absolute, joining with root path: {root_path}")
settings_path = root_path / settings_path
if settings_path.is_dir():
logger.debug(f"Settings path is a directory, loading all .toml files in the directory: {settings_path}")
settings_files = list(settings_path.glob("*.toml"))
elif settings_path.is_file():
logger.debug(f"Settings path is a file, loading specified file: {settings_path}")
settings_files = [settings_path]
else:
raise ValueError(f"Invalid settings path: {settings_path=}, {root_path=}")
return settings_files
def get_or_set_root_path(root_path: Union[str, Path] = None):
env_root_path = os.environ.get("ROOT_PATH")
if env_root_path:
root_path = env_root_path
logger.debug(f"'ROOT_PATH' environment variable is set to {root_path}.")
elif root_path:
logger.info(f"'ROOT_PATH' environment variable is not set, setting to {root_path}.")
os.environ["ROOT_PATH"] = str(root_path)
else:
root_path = Path.cwd()
logger.info(f"'ROOT_PATH' environment variable is not set, defaulting to working directory {root_path}.")
os.environ["ROOT_PATH"] = str(root_path)
return root_path
def get_pyinfra_validators():
import pyinfra.config.validators
return lflatten(
validator for validator in pyinfra.config.validators.__dict__.values() if isinstance(validator, list)
)
def validate_settings(settings: Dynaconf, validators):
settings_valid = True
for validator in validators:
try:
validator.validate(settings)
except ValidationError as e:
settings_valid = False
logger.warning(e)
if not settings_valid:
raise ValidationError("Settings validation failed.")
logger.debug("Settings validated.")
def parse_settings_path():
parser = argparse.ArgumentParser()
parser.add_argument(
"settings_path",
help="Path to settings file(s) or folder(s). Must be .toml file(s) or a folder(s) containing .toml files.",
nargs="+",
)
return parser.parse_args().settings_path


@ -0,0 +1,51 @@
from dynaconf import Validator
queue_manager_validators = [
Validator("rabbitmq.host", must_exist=True, is_type_of=str),
Validator("rabbitmq.port", must_exist=True, is_type_of=int),
Validator("rabbitmq.username", must_exist=True, is_type_of=str),
Validator("rabbitmq.password", must_exist=True, is_type_of=str),
Validator("rabbitmq.heartbeat", must_exist=True, is_type_of=int),
Validator("rabbitmq.connection_sleep", must_exist=True, is_type_of=int),
Validator("rabbitmq.input_queue", must_exist=True, is_type_of=str),
Validator("rabbitmq.output_queue", must_exist=True, is_type_of=str),
Validator("rabbitmq.dead_letter_queue", must_exist=True, is_type_of=str),
]
azure_storage_validators = [
Validator("storage.azure.connection_string", must_exist=True, is_type_of=str),
Validator("storage.azure.container", must_exist=True, is_type_of=str),
]
s3_storage_validators = [
Validator("storage.s3.endpoint", must_exist=True, is_type_of=str),
Validator("storage.s3.key", must_exist=True, is_type_of=str),
Validator("storage.s3.secret", must_exist=True, is_type_of=str),
Validator("storage.s3.region", must_exist=True, is_type_of=str),
Validator("storage.s3.bucket", must_exist=True, is_type_of=str),
]
storage_validators = [
Validator("storage.backend", must_exist=True, is_type_of=str),
]
multi_tenant_storage_validators = [
Validator("storage.tenant_server.endpoint", must_exist=True, is_type_of=str),
Validator("storage.tenant_server.public_key", must_exist=True, is_type_of=str),
]
prometheus_validators = [
Validator("metrics.prometheus.prefix", must_exist=True, is_type_of=str),
Validator("metrics.prometheus.enabled", must_exist=True, is_type_of=bool),
]
webserver_validators = [
Validator("webserver.host", must_exist=True, is_type_of=str),
Validator("webserver.port", must_exist=True, is_type_of=int),
]
opentelemetry_validators = [
Validator("tracing.opentelemetry.endpoint", must_exist=True, is_type_of=str),
Validator("tracing.opentelemetry.service_name", must_exist=True, is_type_of=str),
]

pyinfra/examples.py Normal file

@ -0,0 +1,55 @@
from dynaconf import Dynaconf
from fastapi import FastAPI
from kn_utils.logging import logger
from pyinfra.config.loader import get_pyinfra_validators, validate_settings
from pyinfra.queue.callback import Callback
from pyinfra.queue.manager import QueueManager
from pyinfra.utils.opentelemetry import instrument_pika, setup_trace, instrument_app
from pyinfra.webserver.prometheus import (
add_prometheus_endpoint,
make_prometheus_processing_time_decorator_from_settings,
)
from pyinfra.webserver.utils import (
add_health_check_endpoint,
create_webserver_thread_from_settings,
)
def start_standard_queue_consumer(
callback: Callback,
settings: Dynaconf,
app: FastAPI = None,
):
"""Default serving logic for research services.
Supplies /health, /ready and /prometheus endpoints (if enabled). The callback is monitored for processing time per
message. Also traces the queue messages via openTelemetry (if enabled).
Workload is received via queue messages and processed by the callback function (see pyinfra.queue.callback for
callbacks).
"""
validate_settings(settings, get_pyinfra_validators())
logger.info("Starting webserver and queue consumer...")
app = app or FastAPI()
queue_manager = QueueManager(settings)
if settings.metrics.prometheus.enabled:
logger.info("Prometheus metrics enabled.")
app = add_prometheus_endpoint(app)
callback = make_prometheus_processing_time_decorator_from_settings(settings)(callback)
if settings.tracing.opentelemetry.enabled:
logger.info("OpenTelemetry tracing enabled.")
setup_trace(settings)
instrument_pika()
instrument_app(app)
app = add_health_check_endpoint(app, queue_manager.is_ready)
webserver_thread = create_webserver_thread_from_settings(app, settings)
webserver_thread.start()
queue_manager.start_consuming(callback)


@ -1,5 +0,0 @@
class ProcessingFailure(RuntimeError):
pass
class UnknownStorageBackend(Exception):
pass


@ -1,36 +0,0 @@
import sys
from kn_utils.logging import logger
from pathlib import Path
from pyinfra.queue.queue_manager import token_file_name
def check_token_file():
"""
Checks if the token file of the QueueManager exists and is not empty, i.e. the queue manager has been started.
NOTE: This function suppresses all exceptions.
Returns True if the queue manager has been started, False otherwise
"""
try:
token_file_path = Path(token_file_name())
if token_file_path.exists():
with token_file_path.open(mode="r", encoding="utf8") as token_file:
contents = token_file.read().strip()
return contents != ""
# We intentionally do not handle exception here, since we're only using this in a short script.
# Take care to expand this if the intended use changes
except Exception as err:
logger.warning(f"{err}: Caught exception when reading from token file", exc_info=True)
return False
def run_checks():
if check_token_file():
sys.exit(0)
else:
sys.exit(1)


@ -1,57 +0,0 @@
from funcy import identity
from operator import attrgetter
from prometheus_client import Summary, start_http_server, CollectorRegistry
from time import time
from typing import Callable, Any, Sized
from pyinfra.config import Config
class PrometheusMonitor:
def __init__(self, prefix: str, host: str, port: int):
"""Register the monitoring metrics and start a webserver where they can be scraped at the endpoint
http://{host}:{port}/prometheus
Args:
prefix: should per convention consist of {product_name}_{service_name}_{parameter_to_monitor}
parameter_to_monitor is defined by the result of the processing service.
"""
self.registry = CollectorRegistry()
self.entity_processing_time_sum = Summary(
f"{prefix}_processing_time", "Summed up average processing time per entity observed", registry=self.registry
)
start_http_server(port, host, self.registry)
def __call__(self, process_fn: Callable) -> Callable:
"""Monitor the runtime of a function and update the registered metric with the average runtime per resulting
element.
"""
return self._add_result_monitoring(process_fn)
def _add_result_monitoring(self, process_fn: Callable):
def inner(data: Any, **kwargs):
start = time()
result: Sized = process_fn(data, **kwargs)
runtime = time() - start
if not result:
return result
processing_time_per_entity = runtime / len(result)
self.entity_processing_time_sum.observe(processing_time_per_entity)
return result
return inner
def get_monitor_from_config(config: Config) -> Callable:
if config.monitoring_enabled:
return PrometheusMonitor(*attrgetter("prometheus_metric_prefix", "prometheus_host", "prometheus_port")(config))
else:
return identity


@ -1,199 +0,0 @@
from dataclasses import dataclass
from functools import singledispatch, partial
from funcy import project, complement
from itertools import chain
from operator import itemgetter
from typing import Union, Sized, Callable, List
from pyinfra.config import Config
from pyinfra.utils.file_extension_parsing import make_file_extension_parser
@dataclass
class QueueMessagePayload:
"""Default one-to-one payload, where the message contains the absolute file paths for the target and response files,
that have to be acquired from the storage."""
target_file_path: str
response_file_path: str
target_file_type: Union[str, None]
target_compression_type: Union[str, None]
response_file_type: Union[str, None]
response_compression_type: Union[str, None]
x_tenant_id: Union[str, None]
processing_kwargs: dict
@dataclass
class LegacyQueueMessagePayload(QueueMessagePayload):
"""Legacy one-to-one payload, where the message contains the dossier and file ids, and the file extensions that have
to be used to construct the absolute file paths for the target and response files, that have to be acquired from the
storage."""
dossier_id: str
file_id: str
target_file_extension: str
response_file_extension: str
class QueueMessagePayloadParser:
def __init__(self, payload_matcher2parse_strategy: dict):
self.payload_matcher2parse_strategy = payload_matcher2parse_strategy
def __call__(self, payload: dict) -> QueueMessagePayload:
for payload_matcher, parse_strategy in self.payload_matcher2parse_strategy.items():
if payload_matcher(payload):
return parse_strategy(payload)
def get_queue_message_payload_parser(config: Config) -> QueueMessagePayloadParser:
file_extension_parser = make_file_extension_parser(config.allowed_file_types, config.allowed_compression_types)
payload_matcher2parse_strategy = get_payload_matcher2parse_strategy(
file_extension_parser, config.allowed_processing_parameters
)
return QueueMessagePayloadParser(payload_matcher2parse_strategy)
def get_payload_matcher2parse_strategy(parse_file_extensions: Callable, allowed_processing_parameters: List[str]):
return {
is_legacy_payload: partial(
parse_legacy_queue_message_payload,
parse_file_extensions=parse_file_extensions,
allowed_processing_parameters=allowed_processing_parameters,
),
complement(is_legacy_payload): partial(
parse_queue_message_payload,
parse_file_extensions=parse_file_extensions,
allowed_processing_parameters=allowed_processing_parameters,
),
}
def is_legacy_payload(payload: dict) -> bool:
return {"dossierId", "fileId", "targetFileExtension", "responseFileExtension"}.issubset(payload.keys())
def parse_queue_message_payload(
payload: dict,
parse_file_extensions: Callable,
allowed_processing_parameters: List[str],
) -> QueueMessagePayload:
target_file_path, response_file_path = itemgetter("targetFilePath", "responseFilePath")(payload)
target_file_type, target_compression_type, response_file_type, response_compression_type = chain.from_iterable(
map(parse_file_extensions, [target_file_path, response_file_path])
)
x_tenant_id = payload.get("X-TENANT-ID")
processing_kwargs = project(payload, allowed_processing_parameters)
return QueueMessagePayload(
target_file_path=target_file_path,
response_file_path=response_file_path,
target_file_type=target_file_type,
target_compression_type=target_compression_type,
response_file_type=response_file_type,
response_compression_type=response_compression_type,
x_tenant_id=x_tenant_id,
processing_kwargs=processing_kwargs,
)
def parse_legacy_queue_message_payload(
payload: dict,
parse_file_extensions: Callable,
allowed_processing_parameters: List[str],
) -> LegacyQueueMessagePayload:
dossier_id, file_id, target_file_extension, response_file_extension = itemgetter(
"dossierId", "fileId", "targetFileExtension", "responseFileExtension"
)(payload)
target_file_path = f"{dossier_id}/{file_id}.{target_file_extension}"
response_file_path = f"{dossier_id}/{file_id}.{response_file_extension}"
target_file_type, target_compression_type, response_file_type, response_compression_type = chain.from_iterable(
map(parse_file_extensions, [target_file_extension, response_file_extension])
)
x_tenant_id = payload.get("X-TENANT-ID")
processing_kwargs = project(payload, allowed_processing_parameters)
return LegacyQueueMessagePayload(
dossier_id=dossier_id,
file_id=file_id,
x_tenant_id=x_tenant_id,
target_file_extension=target_file_extension,
response_file_extension=response_file_extension,
target_file_type=target_file_type,
target_compression_type=target_compression_type,
response_file_type=response_file_type,
response_compression_type=response_compression_type,
target_file_path=target_file_path,
response_file_path=response_file_path,
processing_kwargs=processing_kwargs,
)
@singledispatch
def format_service_processing_result_for_storage(payload: QueueMessagePayload, result: Sized) -> dict:
raise NotImplementedError("Unsupported payload type")
@format_service_processing_result_for_storage.register(LegacyQueueMessagePayload)
def _(payload: LegacyQueueMessagePayload, result: Sized) -> dict:
processing_kwargs = payload.processing_kwargs or {}
x_tenant_id = {"X-TENANT-ID": payload.x_tenant_id} if payload.x_tenant_id else {}
return {
"dossierId": payload.dossier_id,
"fileId": payload.file_id,
"targetFileExtension": payload.target_file_extension,
"responseFileExtension": payload.response_file_extension,
**x_tenant_id,
**processing_kwargs,
"data": result,
}
@format_service_processing_result_for_storage.register(QueueMessagePayload)
def _(payload: QueueMessagePayload, result: Sized) -> dict:
processing_kwargs = payload.processing_kwargs or {}
x_tenant_id = {"X-TENANT-ID": payload.x_tenant_id} if payload.x_tenant_id else {}
return {
"targetFilePath": payload.target_file_path,
"responseFilePath": payload.response_file_path,
**x_tenant_id,
**processing_kwargs,
"data": result,
}
@singledispatch
def format_to_queue_message_response_body(queue_message_payload: QueueMessagePayload) -> dict:
raise NotImplementedError("Unsupported payload type")
@format_to_queue_message_response_body.register(LegacyQueueMessagePayload)
def _(payload: LegacyQueueMessagePayload) -> dict:
processing_kwargs = payload.processing_kwargs or {}
x_tenant_id = {"X-TENANT-ID": payload.x_tenant_id} if payload.x_tenant_id else {}
return {"dossierId": payload.dossier_id, "fileId": payload.file_id, **x_tenant_id, **processing_kwargs}
@format_to_queue_message_response_body.register(QueueMessagePayload)
def _(payload: QueueMessagePayload) -> dict:
processing_kwargs = payload.processing_kwargs or {}
x_tenant_id = {"X-TENANT-ID": payload.x_tenant_id} if payload.x_tenant_id else {}
return {
"targetFilePath": payload.target_file_path,
"responseFilePath": payload.response_file_path,
**x_tenant_id,
**processing_kwargs,
}
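The formatters above rely on `functools.singledispatch`, which resolves on the most specific registered class of the first argument. A self-contained sketch with hypothetical payload classes:

```python
from dataclasses import dataclass
from functools import singledispatch

@dataclass
class Payload:
    target_file_path: str

@dataclass
class LegacyPayload(Payload):
    dossier_id: str = ""

@singledispatch
def to_response_body(payload) -> dict:
    raise NotImplementedError("Unsupported payload type")

@to_response_body.register(LegacyPayload)
def _(payload: LegacyPayload) -> dict:
    return {"dossierId": payload.dossier_id}

@to_response_body.register(Payload)
def _(payload: Payload) -> dict:
    return {"targetFilePath": payload.target_file_path}
```

Because `LegacyPayload` subclasses `Payload`, legacy instances hit the legacy formatter even though they are also `Payload`s.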


@@ -1,97 +0,0 @@
from kn_utils.logging import logger
from dataclasses import asdict
from typing import Callable, List
from pyinfra.config import get_config, Config
from pyinfra.payload_processing.monitor import get_monitor_from_config
from pyinfra.payload_processing.payload import (
QueueMessagePayloadParser,
get_queue_message_payload_parser,
format_service_processing_result_for_storage,
format_to_queue_message_response_body,
QueueMessagePayload,
)
from pyinfra.storage.storage import make_downloader, make_uploader
from pyinfra.storage.storage_provider import StorageProvider
class PayloadProcessor:
def __init__(
self,
storage_provider: StorageProvider,
payload_parser: QueueMessagePayloadParser,
data_processor: Callable,
):
"""Wraps an analysis function specified by a service (e.g. NER service) in pre- and post-processing steps.
Args:
storage_provider: Storage manager that connects to the storage, using the tenant id if provided
payload_parser: Parser that translates the queue message payload to the required QueueMessagePayload object
data_processor: The analysis function to be called with the downloaded file
NOTE: The result of the analysis function has to be an instance of `Sized`, e.g. a dict or a list, so that
it can be uploaded and the processing time can be monitored.
"""
self.parse_payload = payload_parser
self.provide_storage = storage_provider
self.process_data = data_processor
def __call__(self, queue_message_payload: dict) -> dict:
"""Processes a queue message payload.
The steps executed are:
1. Download the file specified in the message payload from the storage
2. Process the file with the analysis function
3. Upload the result to the storage
4. Return the payload for a response queue message
Args:
queue_message_payload: The payload of a queue message. The payload is expected to be a dict with the
following keys:
targetFilePath, responseFilePath
OR
dossierId, fileId, targetFileExtension, responseFileExtension
Returns:
The payload for a response queue message, containing only the request payload.
"""
return self._process(queue_message_payload)
def _process(self, queue_message_payload: dict) -> dict:
payload: QueueMessagePayload = self.parse_payload(queue_message_payload)
logger.info(f"Processing {payload.__class__.__name__} ...")
logger.debug(f"Payload contents: {asdict(payload)} ...")
storage, storage_info = self.provide_storage(payload.x_tenant_id)
download_file_to_process = make_downloader(
storage, storage_info.bucket_name, payload.target_file_type, payload.target_compression_type
)
upload_processing_result = make_uploader(
storage, storage_info.bucket_name, payload.response_file_type, payload.response_compression_type
)
data = download_file_to_process(payload.target_file_path)
result: List[dict] = self.process_data(data, **payload.processing_kwargs)
formatted_result = format_service_processing_result_for_storage(payload, result)
upload_processing_result(payload.response_file_path, formatted_result)
return format_to_queue_message_response_body(payload)
def make_payload_processor(data_processor: Callable, config: Config = None) -> PayloadProcessor:
"""Creates a payload processor."""
config = config or get_config()
storage_provider = StorageProvider(config)
monitor = get_monitor_from_config(config)
payload_parser: QueueMessagePayloadParser = get_queue_message_payload_parser(config)
data_processor = monitor(data_processor)
return PayloadProcessor(
storage_provider,
payload_parser,
data_processor,
)

pyinfra/queue/callback.py Normal file

@@ -0,0 +1,39 @@
from typing import Callable, Union
from dynaconf import Dynaconf
from kn_utils.logging import logger
from pyinfra.storage.connection import get_storage
from pyinfra.storage.utils import (
download_data_as_specified_in_message,
upload_data_as_specified_in_message,
)
DataProcessor = Callable[[Union[dict, bytes], dict], Union[dict, list, str]]
Callback = Callable[[dict], dict]
def make_download_process_upload_callback(data_processor: DataProcessor, settings: Dynaconf) -> Callback:
"""Default callback for processing queue messages.
Data will be downloaded from the storage as specified in the message. If a tenant id is specified, the storage
will be configured to use that tenant id, otherwise the storage is configured as specified in the settings.
The data is then passed to the data_processor, together with the message. The data_processor should return a
JSON-serializable object. This object is then uploaded to the storage as specified in the message. The response
message is just the original message.
"""
def inner(queue_message_payload: dict) -> dict:
logger.info(f"Processing payload with download-process-upload callback...")
storage = get_storage(settings, queue_message_payload.get("X-TENANT-ID"))
data = download_data_as_specified_in_message(storage, queue_message_payload)
result = data_processor(data, queue_message_payload)
upload_data_as_specified_in_message(storage, queue_message_payload, result)
return queue_message_payload
return inner
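The callback factory above can be exercised without a real broker or storage backend. A sketch of the same download-process-upload flow with an in-memory dict standing in for the storage helpers (all names here are illustrative):

```python
# In-memory stand-in for the storage used by the download-process-upload flow.
fake_storage: dict = {"dossier-1/file-1.json": b'{"text": "hello"}'}

def download(storage: dict, message: dict) -> bytes:
    return storage[message["targetFilePath"]]

def upload(storage: dict, message: dict, result: bytes) -> None:
    storage[message["responseFilePath"]] = result

def make_callback(data_processor):
    def inner(message: dict) -> dict:
        data = download(fake_storage, message)
        result = data_processor(data, message)
        upload(fake_storage, message, result)
        return message  # the response message is just the original message
    return inner
```

A processor is then any callable of `(data, message)`, e.g. `make_callback(lambda data, msg: data.upper())`.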


@@ -1,40 +0,0 @@
import json
import pika
import pika.exceptions
from pyinfra.config import Config
from pyinfra.queue.queue_manager import QueueManager
class DevelopmentQueueManager(QueueManager):
"""Extends the queue manger with additional functionality that is needed for tests and scripts,
but not in production, such as publishing messages.
"""
def __init__(self, config: Config):
super().__init__(config)
self._open_channel()
def publish_request(self, message: dict, properties: pika.BasicProperties = None):
message_encoded = json.dumps(message).encode("utf-8")
self._channel.basic_publish(
"",
self._input_queue,
properties=properties,
body=message_encoded,
)
def get_response(self):
return self._channel.basic_get(self._output_queue)
def clear_queues(self):
"""purge input & output queues"""
try:
self._channel.queue_purge(self._input_queue)
self._channel.queue_purge(self._output_queue)
except pika.exceptions.ChannelWrongStateError:
pass
def close_channel(self):
self._channel.close()

pyinfra/queue/manager.py Normal file

@@ -0,0 +1,197 @@
import atexit
import concurrent.futures
import json
import logging
import signal
import sys
from typing import Callable, Union
import pika
import pika.exceptions
from dynaconf import Dynaconf
from kn_utils.logging import logger
from pika.adapters.blocking_connection import BlockingChannel, BlockingConnection
from retry import retry
from pyinfra.config.loader import validate_settings
from pyinfra.config.validators import queue_manager_validators
pika_logger = logging.getLogger("pika")
pika_logger.setLevel(logging.WARNING) # disables non-informative pika log clutter
MessageProcessor = Callable[[dict], dict]
class QueueManager:
def __init__(self, settings: Dynaconf):
validate_settings(settings, queue_manager_validators)
self.input_queue = settings.rabbitmq.input_queue
self.output_queue = settings.rabbitmq.output_queue
self.dead_letter_queue = settings.rabbitmq.dead_letter_queue
self.connection_parameters = self.create_connection_parameters(settings)
self.connection: Union[BlockingConnection, None] = None
self.channel: Union[BlockingChannel, None] = None
self.connection_sleep = settings.rabbitmq.connection_sleep
atexit.register(self.stop_consuming)
signal.signal(signal.SIGTERM, self._handle_stop_signal)
signal.signal(signal.SIGINT, self._handle_stop_signal)
@staticmethod
def create_connection_parameters(settings: Dynaconf):
credentials = pika.PlainCredentials(username=settings.rabbitmq.username, password=settings.rabbitmq.password)
pika_connection_params = {
"host": settings.rabbitmq.host,
"port": settings.rabbitmq.port,
"credentials": credentials,
"heartbeat": settings.rabbitmq.heartbeat,
}
return pika.ConnectionParameters(**pika_connection_params)
@retry(tries=3, delay=5, jitter=(1, 3), logger=logger)
def establish_connection(self):
# TODO: set sensible retry parameters
if self.connection and self.connection.is_open:
logger.debug("Connection to RabbitMQ already established.")
return
logger.info("Establishing connection to RabbitMQ...")
self.connection = pika.BlockingConnection(parameters=self.connection_parameters)
logger.debug("Opening channel...")
self.channel = self.connection.channel()
self.channel.basic_qos(prefetch_count=1)
args = {
"x-dead-letter-exchange": "",
"x-dead-letter-routing-key": self.dead_letter_queue,
}
self.channel.queue_declare(self.input_queue, arguments=args, auto_delete=False, durable=True)
self.channel.queue_declare(self.output_queue, arguments=args, auto_delete=False, durable=True)
logger.info("Connection to RabbitMQ established, channel open.")
def is_ready(self):
self.establish_connection()
return self.channel.is_open
@retry(exceptions=pika.exceptions.AMQPConnectionError, tries=3, delay=5, jitter=(1, 3), logger=logger)
def start_consuming(self, message_processor: Callable):
on_message_callback = self._make_on_message_callback(message_processor)
try:
self.establish_connection()
self.channel.basic_consume(self.input_queue, on_message_callback)
self.channel.start_consuming()
except Exception:
logger.error("An unexpected error occurred while consuming messages. Consuming will stop.", exc_info=True)
raise
finally:
self.stop_consuming()
def stop_consuming(self):
if self.channel and self.channel.is_open:
logger.info("Stopping consuming...")
self.channel.stop_consuming()
logger.info("Closing channel...")
self.channel.close()
if self.connection and self.connection.is_open:
logger.info("Closing connection to RabbitMQ...")
self.connection.close()
def publish_message_to_input_queue(self, message: Union[str, bytes, dict], properties: pika.BasicProperties = None):
if isinstance(message, str):
message = message.encode("utf-8")
elif isinstance(message, dict):
message = json.dumps(message).encode("utf-8")
self.establish_connection()
self.channel.basic_publish(
"",
self.input_queue,
properties=properties,
body=message,
)
logger.info(f"Published message to queue {self.input_queue}.")
def purge_queues(self):
self.establish_connection()
try:
self.channel.queue_purge(self.input_queue)
self.channel.queue_purge(self.output_queue)
logger.info("Queues purged.")
except pika.exceptions.ChannelWrongStateError:
pass
def get_message_from_output_queue(self):
self.establish_connection()
return self.channel.basic_get(self.output_queue, auto_ack=True)
def _make_on_message_callback(self, message_processor: MessageProcessor):
def process_message_body_and_await_result(unpacked_message_body):
# Processing the message in a separate thread is necessary for the main thread pika client to be able to
# process data events (e.g. heartbeats) while the message is being processed.
with concurrent.futures.ThreadPoolExecutor(max_workers=1) as thread_pool_executor:
logger.info("Processing payload in separate thread.")
future = thread_pool_executor.submit(message_processor, unpacked_message_body)
# TODO: This block is probably not necessary, but kept since the implications of removing it are
# unclear. Remove it in a future iteration when fewer changes are being made to the code base.
while future.running():
logger.debug("Waiting for payload processing to finish...")
self.connection.sleep(self.connection_sleep)
return future.result()
def on_message_callback(channel, method, properties, body):
logger.info(f"Received message from queue with delivery_tag {method.delivery_tag}.")
if method.redelivered:
logger.warning(f"Declining message with {method.delivery_tag=} due to it being redelivered.")
channel.basic_nack(method.delivery_tag, requeue=False)
return
if body.decode("utf-8") == "STOP":
logger.info(f"Received stop signal, stopping consuming...")
channel.basic_ack(delivery_tag=method.delivery_tag)
self.stop_consuming()
return
try:
filtered_message_headers = (
{k: v for k, v in properties.headers.items() if k.lower().startswith("x-")}
if properties.headers
else {}
)
logger.debug(f"Processing message with {filtered_message_headers=}.")
result: dict = (
process_message_body_and_await_result({**json.loads(body), **filtered_message_headers}) or {}
)
channel.basic_publish(
"",
self.output_queue,
json.dumps(result).encode(),
properties=pika.BasicProperties(headers=filtered_message_headers),
)
logger.info(f"Published result to queue {self.output_queue}.")
channel.basic_ack(delivery_tag=method.delivery_tag)
logger.debug(f"Message with {method.delivery_tag=} acknowledged.")
except Exception:
logger.warning(f"Failed to process message with {method.delivery_tag=}, declining...", exc_info=True)
channel.basic_nack(method.delivery_tag, requeue=False)
raise
return on_message_callback
def _handle_stop_signal(self, signum, *args, **kwargs):
logger.info(f"Received signal {signum}, stopping consuming...")
self.stop_consuming()
sys.exit(0)
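The thread-pool trick in `_make_on_message_callback` keeps the pika client responsive while a message is processed: the worker thread runs the processor while the consuming thread sleeps via `connection.sleep()`, which services AMQP heartbeats. A minimal sketch, with `time.sleep` standing in for `connection.sleep`:

```python
import concurrent.futures
import time

def process_and_wait(processor, payload, poll_interval=0.01):
    # Run the processor in a worker thread; the calling thread polls and
    # sleeps, which in the real manager is connection.sleep() servicing
    # heartbeats so the broker does not drop the connection mid-processing.
    with concurrent.futures.ThreadPoolExecutor(max_workers=1) as pool:
        future = pool.submit(processor, payload)
        while future.running():
            time.sleep(poll_interval)
        return future.result()  # re-raises any exception from the processor
```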


@@ -1,205 +0,0 @@
import atexit
import concurrent.futures
import json
import logging
import pika
import pika.exceptions
import signal
from kn_utils.logging import logger
from pathlib import Path
from pika.adapters.blocking_connection import BlockingChannel
from pyinfra.config import Config
from pyinfra.exception import ProcessingFailure
from pyinfra.payload_processing.processor import PayloadProcessor
from pyinfra.utils.dict import safe_project
CONFIG = Config()
pika_logger = logging.getLogger("pika")
pika_logger.setLevel(logging.WARNING) # disables non-informative pika log clutter
def get_connection_params(config: Config) -> pika.ConnectionParameters:
"""creates pika connection params from pyinfra.Config class
Args:
config (pyinfra.Config): standard pyinfra config class
Returns:
pika.ConnectionParameters: standard pika connection param object
"""
credentials = pika.PlainCredentials(username=config.rabbitmq_username, password=config.rabbitmq_password)
pika_connection_params = {
"host": config.rabbitmq_host,
"port": config.rabbitmq_port,
"credentials": credentials,
"heartbeat": config.rabbitmq_heartbeat,
}
return pika.ConnectionParameters(**pika_connection_params)
def _get_n_previous_attempts(props):
return 0 if props.headers is None else props.headers.get("x-retry-count", 0)
def token_file_name():
"""create filepath
Returns:
joblib.Path: filepath
"""
token_file_path = Path("/tmp") / "consumer_token.txt"
return token_file_path
class QueueManager:
"""Handle RabbitMQ message reception & delivery"""
def __init__(self, config: Config):
self._input_queue = config.request_queue
self._output_queue = config.response_queue
self._dead_letter_queue = config.dead_letter_queue
# controls how often we send out a life signal
self._heartbeat = config.rabbitmq_heartbeat
# controls for how long we only process data events (e.g. heartbeats),
# while the queue is blocked and we process the given callback function
self._connection_sleep = config.rabbitmq_connection_sleep
self._write_token = config.write_consumer_token == "True"
self._set_consumer_token(None)
self._connection_params = get_connection_params(config)
self._connection = pika.BlockingConnection(parameters=self._connection_params)
self._channel: BlockingChannel
# necessary so that pods can be terminated/restarted in K8s/Docker
atexit.register(self.stop_consuming)
signal.signal(signal.SIGTERM, self._handle_stop_signal)
signal.signal(signal.SIGINT, self._handle_stop_signal)
def _set_consumer_token(self, token_value):
self._consumer_token = token_value
if self._write_token:
token_file_path = token_file_name()
with token_file_path.open(mode="w", encoding="utf8") as token_file:
text = token_value if token_value is not None else ""
token_file.write(text)
def _open_channel(self):
self._channel = self._connection.channel()
self._channel.basic_qos(prefetch_count=1)
args = {
"x-dead-letter-exchange": "",
"x-dead-letter-routing-key": self._dead_letter_queue,
}
self._channel.queue_declare(self._input_queue, arguments=args, auto_delete=False, durable=True)
self._channel.queue_declare(self._output_queue, arguments=args, auto_delete=False, durable=True)
def start_consuming(self, process_payload: PayloadProcessor):
"""consumption handling
- standard callback handling is enforced through wrapping process_message_callback in _create_queue_callback
(implements threading to support heartbeats)
- initially sets consumer token to None
- tries to
- open channels
- set consumer token to basic_consume, passing in the standard callback and input queue name
- calls pika start_consuming method on the channels
- catches all Exceptions & stops consuming + closes channels
Args:
process_payload (Callable): function passed to the queue manager, configured by implementing service
"""
callback = self._create_queue_callback(process_payload)
self._set_consumer_token(None)
try:
self._open_channel()
self._set_consumer_token(self._channel.basic_consume(self._input_queue, callback))
logger.info(f"Registered with consumer-tag: {self._consumer_token}")
self._channel.start_consuming()
except Exception:
logger.error(
"An unexpected exception occurred while consuming messages. Consuming will stop.", exc_info=True
)
raise
finally:
self.stop_consuming()
self._connection.close()
def stop_consuming(self):
if self._consumer_token and self._connection:
logger.info(f"Cancelling subscription for consumer-tag {self._consumer_token}")
self._channel.stop_consuming(self._consumer_token)
self._set_consumer_token(None)
def _handle_stop_signal(self, signal_number, _stack_frame, *args, **kwargs):
logger.info(f"Received signal {signal_number}")
self.stop_consuming()
def _create_queue_callback(self, process_payload: PayloadProcessor):
def process_message_body_and_await_result(unpacked_message_body):
with concurrent.futures.ThreadPoolExecutor(max_workers=1) as thread_pool_executor:
logger.debug("Processing payload in separate thread.")
future = thread_pool_executor.submit(process_payload, unpacked_message_body)
while future.running():
logger.debug("Waiting for payload processing to finish...")
self._connection.sleep(float(self._connection_sleep))
try:
return future.result()
except Exception as err:
raise ProcessingFailure(f"QueueMessagePayload processing failed: {repr(err)}") from err
def acknowledge_message_and_publish_response(frame, headers, response_body):
response_properties = pika.BasicProperties(headers=headers) if headers else None
self._channel.basic_publish("", self._output_queue, json.dumps(response_body).encode(), response_properties)
logger.debug(f"Result published, acknowledging incoming message with delivery_tag {frame.delivery_tag}.")
self._channel.basic_ack(frame.delivery_tag)
def callback(_channel, frame, properties, body):
logger.info(f"Received message from queue with delivery_tag {frame.delivery_tag}.")
logger.debug(f"Message headers: {properties.headers}")
# Only try to process each message once. Re-queueing will be handled by the dead-letter-exchange. This
# prevents endless retries on messages that are impossible to process.
if frame.redelivered:
logger.info(
f"Aborting message processing for delivery_tag {frame.delivery_tag} due to it being redelivered.",
)
self._channel.basic_nack(frame.delivery_tag, requeue=False)
return
try:
logger.debug(f"Processing {frame}, {properties}, {body}")
filtered_message_headers = safe_project(properties.headers, ["X-TENANT-ID"])
message_body = {**json.loads(body), **filtered_message_headers}
processing_result = process_message_body_and_await_result(message_body)
logger.info(
f"Processed message with delivery_tag {frame.delivery_tag}, publishing result to result-queue."
)
acknowledge_message_and_publish_response(frame, filtered_message_headers, processing_result)
except ProcessingFailure as err:
logger.info(f"Processing message with delivery_tag {frame.delivery_tag} failed, declining.")
logger.exception(err)
self._channel.basic_nack(frame.delivery_tag, requeue=False)
except Exception:
n_attempts = _get_n_previous_attempts(properties) + 1
logger.warning(f"Failed to process message, {n_attempts}", exc_info=True)
self._channel.basic_nack(frame.delivery_tag, requeue=False)
raise
return callback


@@ -0,0 +1,89 @@
from functools import lru_cache
import requests
from dynaconf import Dynaconf
from kn_utils.logging import logger
from pyinfra.config.loader import validate_settings
from pyinfra.config.validators import (
multi_tenant_storage_validators,
storage_validators,
)
from pyinfra.storage.storages.azure import get_azure_storage_from_settings
from pyinfra.storage.storages.s3 import get_s3_storage_from_settings
from pyinfra.storage.storages.storage import Storage
from pyinfra.utils.cipher import decrypt
def get_storage(settings: Dynaconf, tenant_id: str = None) -> Storage:
"""Establishes a storage connection.
If tenant_id is provided, gets storage connection information from tenant server. These connections are cached.
Otherwise, gets storage connection information from settings.
"""
logger.info("Establishing storage connection...")
if tenant_id:
logger.info(f"Using tenant storage for {tenant_id}.")
validate_settings(settings, multi_tenant_storage_validators)
return get_storage_for_tenant(
tenant_id,
settings.storage.tenant_server.endpoint,
settings.storage.tenant_server.public_key,
)
logger.info("Using default storage.")
validate_settings(settings, storage_validators)
return storage_dispatcher[settings.storage.backend](settings)
storage_dispatcher = {
"azure": get_azure_storage_from_settings,
"s3": get_s3_storage_from_settings,
}
@lru_cache(maxsize=10)
def get_storage_for_tenant(tenant: str, endpoint: str, public_key: str) -> Storage:
response = requests.get(f"{endpoint}/{tenant}").json()
maybe_azure = response.get("azureStorageConnection")
maybe_s3 = response.get("s3StorageConnection")
assert (maybe_azure or maybe_s3) and not (maybe_azure and maybe_s3), "Only one storage backend can be used."
if maybe_azure:
connection_string = decrypt(public_key, maybe_azure["connectionString"])
backend = "azure"
storage_info = {
"storage": {
"azure": {
"connection_string": connection_string,
"container": maybe_azure["containerName"],
},
}
}
elif maybe_s3:
secret = decrypt(public_key, maybe_s3["secret"])
backend = "s3"
storage_info = {
"storage": {
"s3": {
"endpoint": maybe_s3["endpoint"],
"key": maybe_s3["key"],
"secret": secret,
"region": maybe_s3["region"],
"bucket": maybe_s3["bucketName"],
},
}
}
else:
raise Exception(f"Unknown storage backend in {response}.")
storage_settings = Dynaconf()
storage_settings.update(storage_info)
storage = storage_dispatcher[backend](storage_settings)
return storage
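`get_storage` selects the backend constructor from a dict keyed by backend name. The same idiom in miniature (the constructors and settings shape here are illustrative, not the real storage API):

```python
def make_azure_storage(settings: dict) -> str:
    return f"azure container {settings['container']}"

def make_s3_storage(settings: dict) -> str:
    return f"s3 bucket {settings['bucket']}"

storage_dispatch = {"azure": make_azure_storage, "s3": make_s3_storage}

def connect(settings: dict) -> str:
    backend = settings.get("backend")
    if backend not in storage_dispatch:
        raise ValueError(f"Unknown storage backend {backend!r}")
    return storage_dispatch[backend](settings)
```

Adding a backend is then a one-line dict entry rather than another `elif` branch.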


@@ -1,48 +0,0 @@
from functools import lru_cache, partial
from typing import Callable
from funcy import compose
from pyinfra.config import Config
from pyinfra.storage.storage_info import get_storage_info_from_config, get_storage_from_storage_info
from pyinfra.storage.storages.interface import Storage
from pyinfra.utils.compressing import get_decompressor, get_compressor
from pyinfra.utils.encoding import get_decoder, get_encoder
def get_storage_from_config(config: Config) -> Storage:
storage_info = get_storage_info_from_config(config)
storage = get_storage_from_storage_info(storage_info)
return storage
def verify_existence(storage: Storage, bucket: str, file_name: str) -> str:
if not storage.exists(bucket, file_name):
raise FileNotFoundError(f"{file_name=} name not found on storage in {bucket=}.")
return file_name
@lru_cache(maxsize=10)
def make_downloader(storage: Storage, bucket: str, file_type: str, compression_type: str) -> Callable:
verify = partial(verify_existence, storage, bucket)
download = partial(storage.get_object, bucket)
decompress = get_decompressor(compression_type)
decode = get_decoder(file_type)
return compose(decode, decompress, download, verify)
@lru_cache(maxsize=10)
def make_uploader(storage: Storage, bucket: str, file_type: str, compression_type: str) -> Callable:
upload = partial(storage.put_object, bucket)
compress = get_compressor(compression_type)
encode = get_encoder(file_type)
def inner(file_name, file_bytes):
upload(file_name, compose(compress, encode)(file_bytes))
return inner
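`make_downloader` builds its pipeline with `funcy.compose`, applied right-to-left: verify, then download, decompress, and decode. A dependency-free sketch of the same composition, using a toy store and toy codecs rather than the real storage API:

```python
from functools import reduce

def compose(*funcs):
    # Right-to-left composition, like funcy.compose: compose(f, g)(x) == f(g(x)).
    return reduce(lambda f, g: lambda x: f(g(x)), funcs)

store = {"doc.json.gz": b"gz:hello"}

def verify(name: str) -> str:
    if name not in store:
        raise FileNotFoundError(name)
    return name

def download(name: str) -> bytes:
    return store[name]

def decompress(data: bytes) -> bytes:
    return data[len(b"gz:"):]  # toy decompression

def decode(data: bytes) -> str:
    return data.decode("utf-8")

downloader = compose(decode, decompress, download, verify)
```

The rightmost function (`verify`) runs first, so a missing file fails fast before any bytes are fetched.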


@@ -1,125 +0,0 @@
from dataclasses import dataclass
import requests
from azure.storage.blob import BlobServiceClient
from minio import Minio
from pyinfra.config import Config
from pyinfra.exception import UnknownStorageBackend
from pyinfra.storage.storages.azure import AzureStorage
from pyinfra.storage.storages.interface import Storage
from pyinfra.storage.storages.s3 import S3Storage
from pyinfra.utils.cipher import decrypt
from pyinfra.utils.url_parsing import validate_and_parse_s3_endpoint
@dataclass(frozen=True)
class StorageInfo:
bucket_name: str
@dataclass(frozen=True)
class AzureStorageInfo(StorageInfo):
connection_string: str
def __hash__(self):
return hash(self.connection_string)
def __eq__(self, other):
if not isinstance(other, AzureStorageInfo):
return False
return self.connection_string == other.connection_string
@dataclass(frozen=True)
class S3StorageInfo(StorageInfo):
secure: bool
endpoint: str
access_key: str
secret_key: str
region: str
def __hash__(self):
return hash((self.secure, self.endpoint, self.access_key, self.secret_key, self.region))
def __eq__(self, other):
if not isinstance(other, S3StorageInfo):
return False
return (
self.secure == other.secure
and self.endpoint == other.endpoint
and self.access_key == other.access_key
and self.secret_key == other.secret_key
and self.region == other.region
)
def get_storage_from_storage_info(storage_info: StorageInfo) -> Storage:
if isinstance(storage_info, AzureStorageInfo):
return AzureStorage(BlobServiceClient.from_connection_string(conn_str=storage_info.connection_string))
elif isinstance(storage_info, S3StorageInfo):
return S3Storage(
Minio(
secure=storage_info.secure,
endpoint=storage_info.endpoint,
access_key=storage_info.access_key,
secret_key=storage_info.secret_key,
region=storage_info.region,
)
)
else:
raise UnknownStorageBackend()
def get_storage_info_from_endpoint(public_key: str, endpoint: str, x_tenant_id: str) -> StorageInfo:
resp = requests.get(f"{endpoint}/{x_tenant_id}").json()
maybe_azure = resp.get("azureStorageConnection")
maybe_s3 = resp.get("s3StorageConnection")
assert not (maybe_azure and maybe_s3)
if maybe_azure:
connection_string = decrypt(public_key, maybe_azure["connectionString"])
storage_info = AzureStorageInfo(
connection_string=connection_string,
bucket_name=maybe_azure["containerName"],
)
elif maybe_s3:
secure, endpoint = validate_and_parse_s3_endpoint(maybe_s3["endpoint"])
secret = decrypt(public_key, maybe_s3["secret"])
storage_info = S3StorageInfo(
secure=secure,
endpoint=endpoint,
access_key=maybe_s3["key"],
secret_key=secret,
region=maybe_s3["region"],
bucket_name=maybe_s3["bucketName"],
)
else:
raise UnknownStorageBackend()
return storage_info
def get_storage_info_from_config(config: Config) -> StorageInfo:
if config.storage_backend == "s3":
storage_info = S3StorageInfo(
secure=config.storage_secure_connection,
endpoint=config.storage_endpoint,
access_key=config.storage_key,
secret_key=config.storage_secret,
region=config.storage_region,
bucket_name=config.storage_bucket,
)
elif config.storage_backend == "azure":
storage_info = AzureStorageInfo(
connection_string=config.storage_azureconnectionstring,
bucket_name=config.storage_bucket,
)
else:
raise UnknownStorageBackend(f"Unknown storage backend '{config.storage_backend}'.")
return storage_info


@@ -1,55 +0,0 @@
from dataclasses import asdict
from functools import partial, lru_cache
from kn_utils.logging import logger
from typing import Tuple
from pyinfra.config import Config
from pyinfra.storage.storage_info import (
get_storage_info_from_config,
get_storage_info_from_endpoint,
StorageInfo,
get_storage_from_storage_info,
)
from pyinfra.storage.storages.interface import Storage
class StorageProvider:
def __init__(self, config: Config):
self.config = config
self.default_storage_info: StorageInfo = get_storage_info_from_config(config)
self.get_storage_info_from_tenant_id = partial(
get_storage_info_from_endpoint,
config.tenant_decryption_public_key,
config.tenant_endpoint,
)
def __call__(self, *args, **kwargs):
return self._connect(*args, **kwargs)
@lru_cache(maxsize=32)
def _connect(self, x_tenant_id=None) -> Tuple[Storage, StorageInfo]:
storage_info = self._get_storage_info(x_tenant_id)
storage_connection = get_storage_from_storage_info(storage_info)
return storage_connection, storage_info
def _get_storage_info(self, x_tenant_id=None):
if x_tenant_id:
storage_info = self.get_storage_info_from_tenant_id(x_tenant_id)
logger.debug(f"Received {storage_info.__class__.__name__} for {x_tenant_id} from endpoint.")
logger.trace(f"{asdict(storage_info)}")
else:
storage_info = self.default_storage_info
logger.debug(f"Using local default {storage_info.__class__.__name__} for {x_tenant_id}.")
logger.trace(f"{asdict(storage_info)}")
return storage_info
class StorageProviderMock(StorageProvider):
def __init__(self, storage, storage_info):
self.storage = storage
self.storage_info = storage_info
def __call__(self, *args, **kwargs):
return self.storage, self.storage_info

@@ -1,59 +1,67 @@
import logging
from azure.storage.blob import BlobServiceClient, ContainerClient
from itertools import repeat
from kn_utils.logging import logger
from operator import attrgetter
from azure.storage.blob import BlobServiceClient, ContainerClient
from dynaconf import Dynaconf
from kn_utils.logging import logger
from retry import retry
from pyinfra.config import Config
from pyinfra.storage.storages.interface import Storage
from pyinfra.config.loader import validate_settings
from pyinfra.config.validators import azure_storage_validators
from pyinfra.storage.storages.storage import Storage
logging.getLogger("azure").setLevel(logging.WARNING)
logging.getLogger("urllib3").setLevel(logging.WARNING)
class AzureStorage(Storage):
def __init__(self, client: BlobServiceClient):
def __init__(self, client: BlobServiceClient, bucket: str):
self._client: BlobServiceClient = client
self._bucket = bucket
def has_bucket(self, bucket_name):
container_client = self._client.get_container_client(bucket_name)
@property
def bucket(self):
return self._bucket
def has_bucket(self):
container_client = self._client.get_container_client(self.bucket)
return container_client.exists()
def make_bucket(self, bucket_name):
container_client = self._client.get_container_client(bucket_name)
container_client if container_client.exists() else self._client.create_container(bucket_name)
    def make_bucket(self):
        container_client = self._client.get_container_client(self.bucket)
        if not container_client.exists():
            self._client.create_container(self.bucket)
def __provide_container_client(self, bucket_name) -> ContainerClient:
self.make_bucket(bucket_name)
container_client = self._client.get_container_client(bucket_name)
def __provide_container_client(self) -> ContainerClient:
self.make_bucket()
container_client = self._client.get_container_client(self.bucket)
return container_client
def put_object(self, bucket_name, object_name, data):
def put_object(self, object_name, data):
logger.debug(f"Uploading '{object_name}'...")
container_client = self.__provide_container_client(bucket_name)
container_client = self.__provide_container_client()
blob_client = container_client.get_blob_client(object_name)
blob_client.upload_blob(data, overwrite=True)
def exists(self, bucket_name, object_name):
container_client = self.__provide_container_client(bucket_name)
def exists(self, object_name):
container_client = self.__provide_container_client()
blob_client = container_client.get_blob_client(object_name)
return blob_client.exists()
@retry(tries=3, delay=5, jitter=(1, 3))
def get_object(self, bucket_name, object_name):
def get_object(self, object_name):
logger.debug(f"Downloading '{object_name}'...")
try:
container_client = self.__provide_container_client(bucket_name)
container_client = self.__provide_container_client()
blob_client = container_client.get_blob_client(object_name)
blob_data = blob_client.download_blob()
return blob_data.readall()
except Exception as err:
raise Exception("Failed getting object from azure client") from err
def get_all_objects(self, bucket_name):
container_client = self.__provide_container_client(bucket_name)
def get_all_objects(self):
container_client = self.__provide_container_client()
blobs = container_client.list_blobs()
for blob in blobs:
logger.debug(f"Downloading '{blob.name}'...")
@@ -62,17 +70,22 @@ class AzureStorage(Storage):
data = blob_data.readall()
yield data
def clear_bucket(self, bucket_name):
logger.debug(f"Clearing Azure container '{bucket_name}'...")
container_client = self._client.get_container_client(bucket_name)
def clear_bucket(self):
logger.debug(f"Clearing Azure container '{self.bucket}'...")
container_client = self._client.get_container_client(self.bucket)
blobs = container_client.list_blobs()
container_client.delete_blobs(*blobs)
def get_all_object_names(self, bucket_name):
container_client = self.__provide_container_client(bucket_name)
def get_all_object_names(self):
container_client = self.__provide_container_client()
blobs = container_client.list_blobs()
return zip(repeat(bucket_name), map(attrgetter("name"), blobs))
return zip(repeat(self.bucket), map(attrgetter("name"), blobs))
def get_azure_storage_from_config(config: Config):
return AzureStorage(BlobServiceClient.from_connection_string(conn_str=config.storage_azureconnectionstring))
def get_azure_storage_from_settings(settings: Dynaconf):
validate_settings(settings, azure_storage_validators)
return AzureStorage(
client=BlobServiceClient.from_connection_string(conn_str=settings.storage.azure.connection_string),
bucket=settings.storage.azure.container,
)

@@ -1,36 +0,0 @@
from pyinfra.storage.storages.interface import Storage
class StorageMock(Storage):
def __init__(self, data: bytes = None, file_name: str = None, bucket: str = None):
self.data = data
self.file_name = file_name
self.bucket = bucket
def make_bucket(self, bucket_name):
self.bucket = bucket_name
def has_bucket(self, bucket_name):
return self.bucket == bucket_name
def put_object(self, bucket_name, object_name, data):
self.bucket = bucket_name
self.file_name = object_name
self.data = data
def exists(self, bucket_name, object_name):
return self.bucket == bucket_name and self.file_name == object_name
def get_object(self, bucket_name, object_name):
return self.data
def get_all_objects(self, bucket_name):
raise NotImplementedError
def clear_bucket(self, bucket_name):
self.bucket = None
self.file_name = None
self.data = None
def get_all_object_names(self, bucket_name):
raise NotImplementedError

@@ -1,44 +1,53 @@
import io
from itertools import repeat
from operator import attrgetter
from dynaconf import Dynaconf
from kn_utils.logging import logger
from minio import Minio
from operator import attrgetter
from retry import retry
from pyinfra.config import Config
from pyinfra.storage.storages.interface import Storage
from pyinfra.config.loader import validate_settings
from pyinfra.config.validators import s3_storage_validators
from pyinfra.storage.storages.storage import Storage
from pyinfra.utils.url_parsing import validate_and_parse_s3_endpoint
class S3Storage(Storage):
def __init__(self, client: Minio):
def __init__(self, client: Minio, bucket: str):
self._client = client
self._bucket = bucket
def make_bucket(self, bucket_name):
if not self.has_bucket(bucket_name):
self._client.make_bucket(bucket_name)
@property
def bucket(self):
return self._bucket
def has_bucket(self, bucket_name):
return self._client.bucket_exists(bucket_name)
def make_bucket(self):
if not self.has_bucket():
self._client.make_bucket(self.bucket)
def put_object(self, bucket_name, object_name, data):
def has_bucket(self):
return self._client.bucket_exists(self.bucket)
def put_object(self, object_name, data):
logger.debug(f"Uploading '{object_name}'...")
data = io.BytesIO(data)
self._client.put_object(bucket_name, object_name, data, length=data.getbuffer().nbytes)
self._client.put_object(self.bucket, object_name, data, length=data.getbuffer().nbytes)
def exists(self, bucket_name, object_name):
def exists(self, object_name):
try:
self._client.stat_object(bucket_name, object_name)
self._client.stat_object(self.bucket, object_name)
return True
except Exception:
return False
@retry(tries=3, delay=5, jitter=(1, 3))
def get_object(self, bucket_name, object_name):
def get_object(self, object_name):
logger.debug(f"Downloading '{object_name}'...")
response = None
try:
response = self._client.get_object(bucket_name, object_name)
response = self._client.get_object(self.bucket, object_name)
return response.data
except Exception as err:
raise Exception("Failed getting object from s3 client") from err
@@ -47,29 +56,34 @@ class S3Storage(Storage):
response.close()
response.release_conn()
def get_all_objects(self, bucket_name):
for obj in self._client.list_objects(bucket_name, recursive=True):
def get_all_objects(self):
for obj in self._client.list_objects(self.bucket, recursive=True):
logger.debug(f"Downloading '{obj.object_name}'...")
yield self.get_object(bucket_name, obj.object_name)
yield self.get_object(obj.object_name)
def clear_bucket(self, bucket_name):
logger.debug(f"Clearing S3 bucket '{bucket_name}'...")
objects = self._client.list_objects(bucket_name, recursive=True)
def clear_bucket(self):
logger.debug(f"Clearing S3 bucket '{self.bucket}'...")
objects = self._client.list_objects(self.bucket, recursive=True)
for obj in objects:
self._client.remove_object(bucket_name, obj.object_name)
self._client.remove_object(self.bucket, obj.object_name)
def get_all_object_names(self, bucket_name):
objs = self._client.list_objects(bucket_name, recursive=True)
return zip(repeat(bucket_name), map(attrgetter("object_name"), objs))
def get_all_object_names(self):
objs = self._client.list_objects(self.bucket, recursive=True)
return zip(repeat(self.bucket), map(attrgetter("object_name"), objs))
def get_s3_storage_from_config(config: Config):
def get_s3_storage_from_settings(settings: Dynaconf):
validate_settings(settings, s3_storage_validators)
secure, endpoint = validate_and_parse_s3_endpoint(settings.storage.s3.endpoint)
return S3Storage(
Minio(
secure=config.storage_secure_connection,
endpoint=config.storage_endpoint,
access_key=config.storage_key,
secret_key=config.storage_secret,
region=config.storage_region,
)
client=Minio(
secure=secure,
endpoint=endpoint,
access_key=settings.storage.s3.key,
secret_key=settings.storage.s3.secret,
region=settings.storage.s3.region,
),
bucket=settings.storage.s3.bucket,
)

@@ -2,34 +2,39 @@ from abc import ABC, abstractmethod
class Storage(ABC):
@property
@abstractmethod
def make_bucket(self, bucket_name):
def bucket(self):
raise NotImplementedError
@abstractmethod
def has_bucket(self, bucket_name):
def make_bucket(self):
raise NotImplementedError
@abstractmethod
def put_object(self, bucket_name, object_name, data):
def has_bucket(self):
raise NotImplementedError
@abstractmethod
def exists(self, bucket_name, object_name):
def put_object(self, object_name, data):
raise NotImplementedError
@abstractmethod
def get_object(self, bucket_name, object_name):
def exists(self, object_name):
raise NotImplementedError
@abstractmethod
def get_all_objects(self, bucket_name):
def get_object(self, object_name):
raise NotImplementedError
@abstractmethod
def clear_bucket(self, bucket_name):
def get_all_objects(self):
raise NotImplementedError
@abstractmethod
def get_all_object_names(self, bucket_name):
def clear_bucket(self):
raise NotImplementedError
@abstractmethod
def get_all_object_names(self):
raise NotImplementedError
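The refactored interface binds each `Storage` instance to a single bucket instead of threading `bucket_name` through every call. A minimal in-memory implementation of the new contract (illustrative only, not part of the library; the class names below are chosen for this sketch):

```python
from abc import ABC, abstractmethod


class Storage(ABC):
    # Abridged copy of the bucket-bound interface above, so the example is self-contained.
    @property
    @abstractmethod
    def bucket(self): ...

    @abstractmethod
    def put_object(self, object_name, data): ...

    @abstractmethod
    def get_object(self, object_name): ...

    @abstractmethod
    def exists(self, object_name): ...


class InMemoryStorage(Storage):
    def __init__(self, bucket: str):
        self._bucket = bucket
        self._objects: dict[str, bytes] = {}

    @property
    def bucket(self):
        return self._bucket

    def put_object(self, object_name, data):
        self._objects[object_name] = data

    def get_object(self, object_name):
        return self._objects[object_name]

    def exists(self, object_name):
        return object_name in self._objects


storage = InMemoryStorage("research")
storage.put_object("dossier/file.json", b"{}")
assert storage.exists("dossier/file.json")
assert storage.bucket == "research"
```

Note that callers no longer pass a bucket per call; the bucket is fixed at construction time, which is what allows `get_s3_storage_from_settings` and `get_azure_storage_from_settings` to resolve it once from settings.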

pyinfra/storage/utils.py (new file)

@@ -0,0 +1,107 @@
import gzip
import json
from typing import Union
from kn_utils.logging import logger
from pydantic import BaseModel, ValidationError
from pyinfra.storage.storages.storage import Storage
class DossierIdFileIdDownloadPayload(BaseModel):
dossierId: str
fileId: str
targetFileExtension: str
@property
def targetFilePath(self):
return f"{self.dossierId}/{self.fileId}.{self.targetFileExtension}"
class DossierIdFileIdUploadPayload(BaseModel):
dossierId: str
fileId: str
responseFileExtension: str
@property
def responseFilePath(self):
return f"{self.dossierId}/{self.fileId}.{self.responseFileExtension}"
class TargetResponseFilePathDownloadPayload(BaseModel):
targetFilePath: str
class TargetResponseFilePathUploadPayload(BaseModel):
responseFilePath: str
def download_data_as_specified_in_message(storage: Storage, raw_payload: dict) -> Union[dict, bytes]:
"""Convenience function to download a file specified in a message payload.
Supports both legacy and new payload formats.
If the content is compressed with gzip (.gz), it will be decompressed (-> bytes).
If the content is a json file, it will be decoded (-> dict).
If no file is specified in the payload or the file does not exist in storage, an exception will be raised.
In all other cases, the content will be returned as is (-> bytes).
    This function can be extended in the future as needed (e.g. handling of more file types). Since further
    requirements are not yet specified, the code is kept simple for now to aid readability and maintainability,
    and to avoid refactoring generic solutions that turn out not to be as generic as they seemed.
"""
try:
if "dossierId" in raw_payload:
payload = DossierIdFileIdDownloadPayload(**raw_payload)
else:
payload = TargetResponseFilePathDownloadPayload(**raw_payload)
except ValidationError:
raise ValueError("No download file path found in payload, nothing to download.")
if not storage.exists(payload.targetFilePath):
raise FileNotFoundError(f"File '{payload.targetFilePath}' does not exist in storage.")
data = storage.get_object(payload.targetFilePath)
data = gzip.decompress(data) if ".gz" in payload.targetFilePath else data
data = json.loads(data.decode("utf-8")) if ".json" in payload.targetFilePath else data
logger.info(f"Downloaded {payload.targetFilePath} from storage.")
return data
def upload_data_as_specified_in_message(storage: Storage, raw_payload: dict, data):
"""Convenience function to upload a file specified in a message payload. For now, only json serializable data is
supported. The storage json consists of the raw_payload, which is extended with a 'data' key, containing the
data to be uploaded.
If the content is not a json serializable object, an exception will be raised.
If the result file identifier specifies compression with gzip (.gz), it will be compressed before upload.
    This function can be extended in the future as needed (e.g. if we need to upload images). Since further
    requirements are not yet specified, the code is kept simple for now to aid readability and maintainability,
    and to avoid refactoring generic solutions that turn out not to be as generic as they seemed.
"""
try:
if "dossierId" in raw_payload:
payload = DossierIdFileIdUploadPayload(**raw_payload)
else:
payload = TargetResponseFilePathUploadPayload(**raw_payload)
except ValidationError:
raise ValueError("No upload file path found in payload, nothing to upload.")
if ".json" not in payload.responseFilePath:
raise ValueError("Only json serializable data can be uploaded.")
data = {**raw_payload, "data": data}
data = json.dumps(data).encode("utf-8")
data = gzip.compress(data) if ".gz" in payload.responseFilePath else data
storage.put_object(payload.responseFilePath, data)
logger.info(f"Uploaded {payload.responseFilePath} to storage.")
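The two helpers above agree on a simple convention: the object path is either given directly (`targetFilePath`/`responseFilePath`) or derived as `{dossierId}/{fileId}.{extension}` for legacy payloads, and the `.json`/`.gz` suffixes select JSON decoding and gzip (de)compression. A stdlib-only sketch of that round trip, with no storage backend and helper names chosen here for illustration:

```python
import gzip
import json


def derive_path(payload: dict) -> str:
    # Legacy payloads carry dossierId/fileId; new ones carry the full path.
    if "dossierId" in payload:
        return f"{payload['dossierId']}/{payload['fileId']}.{payload['targetFileExtension']}"
    return payload["targetFilePath"]


def encode(data: dict, path: str) -> bytes:
    raw = json.dumps(data).encode("utf-8")
    return gzip.compress(raw) if ".gz" in path else raw


def decode(blob: bytes, path: str):
    blob = gzip.decompress(blob) if ".gz" in path else blob
    return json.loads(blob.decode("utf-8")) if ".json" in path else blob


legacy = {"dossierId": "dossier", "fileId": "file", "targetFileExtension": "json.gz"}
path = derive_path(legacy)  # "dossier/file.json.gz"
blob = encode({"numberOfPages": 7}, path)
assert decode(blob, path) == {"numberOfPages": 7}
```

The real helpers add pydantic validation and storage existence checks on top of exactly this suffix-driven logic.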

@@ -1,22 +0,0 @@
import gzip
from typing import Union, Callable
from funcy import identity
def get_decompressor(compression_type: Union[str, None]) -> Callable:
if not compression_type:
return identity
elif "gz" in compression_type:
return gzip.decompress
else:
raise ValueError(f"{compression_type=} is not supported.")
def get_compressor(compression_type: str) -> Callable:
if not compression_type:
return identity
elif "gz" in compression_type:
return gzip.compress
else:
raise ValueError(f"{compression_type=} is not supported.")

@@ -1,5 +0,0 @@
from funcy import project
def safe_project(mapping, keys) -> dict:
return project(mapping, keys) if mapping else {}

@@ -1,28 +0,0 @@
import json
from typing import Callable
from funcy import identity
def decode_json(data: bytes) -> dict:
return json.loads(data.decode("utf-8"))
def encode_json(data: dict) -> bytes:
return json.dumps(data).encode("utf-8")
def get_decoder(file_type: str) -> Callable:
if "json" in file_type:
return decode_json
elif "pdf" in file_type:
return identity
else:
raise ValueError(f"{file_type=} is not supported.")
def get_encoder(file_type: str) -> Callable:
if "json" in file_type:
return encode_json
else:
raise ValueError(f"{file_type=} is not supported.")

@@ -1,41 +0,0 @@
from collections import defaultdict
from typing import Callable
from funcy import merge
def make_file_extension_parser(file_types, compression_types):
ext2_type2ext = make_ext2_type2ext(file_types, compression_types)
ext_to_type2ext = make_ext_to_type2ext(ext2_type2ext)
def inner(path):
file_extensions = parse_file_extensions(path, ext_to_type2ext)
return file_extensions.get("file_type"), file_extensions.get("compression_type")
return inner
def make_ext2_type2ext(file_type_extensions, compression_type_extensions):
def make_ext_to_ext2type(ext_type):
return lambda ext: {ext_type: ext}
ext_to_file_type_mapper = make_ext_to_ext2type("file_type")
ext_to_compression_type_mapper = make_ext_to_ext2type("compression_type")
return defaultdict(
lambda: lambda _: {},
{
**{e: ext_to_file_type_mapper for e in file_type_extensions},
**{e: ext_to_compression_type_mapper for e in compression_type_extensions},
},
)
def make_ext_to_type2ext(ext2_type2ext):
def ext_to_type2ext(ext):
return ext2_type2ext[ext](ext)
return ext_to_type2ext
def parse_file_extensions(path, ext_to_type2ext: Callable):
return merge(*map(ext_to_type2ext, path.split(".")))

@@ -0,0 +1,73 @@
import json
from dynaconf import Dynaconf
from fastapi import FastAPI
from opentelemetry import trace
from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import OTLPSpanExporter
from opentelemetry.instrumentation.fastapi import FastAPIInstrumentor
from opentelemetry.instrumentation.pika import PikaInstrumentor
from opentelemetry.sdk.resources import Resource
from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.sdk.trace.export import (
BatchSpanProcessor,
ConsoleSpanExporter,
SpanExporter,
SpanExportResult,
)
from pyinfra.config.loader import validate_settings
from pyinfra.config.validators import opentelemetry_validators
class JsonSpanExporter(SpanExporter):
def __init__(self):
self.traces = []
def export(self, spans):
for span in spans:
self.traces.append(json.loads(span.to_json()))
return SpanExportResult.SUCCESS
def shutdown(self):
pass
def setup_trace(settings: Dynaconf, service_name: str = None, exporter: SpanExporter = None):
service_name = service_name or settings.tracing.opentelemetry.service_name
exporter = exporter or get_exporter(settings)
resource = Resource(attributes={"service.name": service_name})
provider = TracerProvider(resource=resource, shutdown_on_exit=True)
processor = BatchSpanProcessor(exporter)
provider.add_span_processor(processor)
# TODO: trace.set_tracer_provider produces a warning if trying to set the provider twice.
# "WARNING opentelemetry.trace:__init__.py:521 Overriding of current TracerProvider is not allowed"
    # This doesn't seem to affect the functionality since we only want to use the tracer provider set in the beginning.
# We work around the log message by using the protected method with log=False.
trace._set_tracer_provider(provider, log=False)
def get_exporter(settings: Dynaconf):
validate_settings(settings, validators=opentelemetry_validators)
if settings.tracing.opentelemetry.exporter == "json":
return JsonSpanExporter()
elif settings.tracing.opentelemetry.exporter == "otlp":
return OTLPSpanExporter(endpoint=settings.tracing.opentelemetry.endpoint)
elif settings.tracing.opentelemetry.exporter == "console":
return ConsoleSpanExporter()
else:
raise ValueError(
f"Invalid OpenTelemetry exporter {settings.tracing.opentelemetry.exporter}. "
f"Valid values are 'json', 'otlp' and 'console'."
)
def instrument_pika():
PikaInstrumentor().instrument()
def instrument_app(app: FastAPI, excluded_urls: str = "/health,/ready,/prometheus"):
FastAPIInstrumentor().instrument_app(app, excluded_urls=excluded_urls)

@@ -0,0 +1,64 @@
from time import time
from typing import Callable, TypeVar
from dynaconf import Dynaconf
from fastapi import FastAPI
from funcy import identity
from prometheus_client import REGISTRY, CollectorRegistry, Summary, generate_latest
from starlette.responses import Response
from pyinfra.config.loader import validate_settings
from pyinfra.config.validators import prometheus_validators
def add_prometheus_endpoint(app: FastAPI, registry: CollectorRegistry = REGISTRY) -> FastAPI:
"""Add a prometheus endpoint to the app. It is recommended to use the default global registry.
You can register your own metrics with it anywhere, and they will be scraped with this endpoint.
See https://prometheus.io/docs/concepts/metric_types/ for the different metric types.
The implementation for monitoring the processing time of a function is in the decorator below (decorate the
processing function of a service to assess the processing time of each call).
The convention for the metric name is {product_name}_{service_name}_{parameter_to_monitor}.
"""
@app.get("/prometheus")
def prometheus_metrics():
return Response(generate_latest(registry), media_type="text/plain")
return app
Decorator = TypeVar("Decorator", bound=Callable[[Callable], Callable])
def make_prometheus_processing_time_decorator_from_settings(
settings: Dynaconf,
postfix: str = "processing_time",
registry: CollectorRegistry = REGISTRY,
) -> Decorator:
    """Make a decorator for monitoring the processing time of a function. This and other metrics should follow the
    naming convention {product name}_{service name}_{processing step / parameter to monitor}.
    """
validate_settings(settings, validators=prometheus_validators)
processing_time_sum = Summary(
f"{settings.metrics.prometheus.prefix}_{postfix}",
"Summed up processing time per call.",
registry=registry,
)
def decorator(process_fn: Callable) -> Callable:
def inner(*args, **kwargs):
start = time()
result = process_fn(*args, **kwargs)
runtime = time() - start
processing_time_sum.observe(runtime)
return result
return inner
return decorator
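The decorator pattern above (observe wall-clock time around each call and feed it to a `Summary`) can be sketched without `prometheus_client` by substituting a plain accumulator for the metric; all names below are illustrative, not library API:

```python
from time import perf_counter
from typing import Callable


class SummaryStub:
    """Stands in for prometheus_client.Summary: counts observations and sums them."""

    def __init__(self):
        self.count = 0
        self.total = 0.0

    def observe(self, value: float):
        self.count += 1
        self.total += value


def make_timing_decorator(summary: SummaryStub) -> Callable:
    def decorator(process_fn: Callable) -> Callable:
        def inner(*args, **kwargs):
            start = perf_counter()
            result = process_fn(*args, **kwargs)
            summary.observe(perf_counter() - start)  # record elapsed seconds per call
            return result

        return inner

    return decorator


processing_time = SummaryStub()


@make_timing_decorator(processing_time)
def process(x):
    return x * 2


assert process(21) == 42
assert processing_time.count == 1
```

With the real `Summary`, the `/prometheus` endpoint added above then exposes `_count` and `_sum` series for the metric, from which dashboards can derive average processing time per call.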

@@ -0,0 +1,60 @@
import logging
import threading
from typing import Callable
import uvicorn
from dynaconf import Dynaconf
from fastapi import FastAPI
from pyinfra.config.loader import validate_settings
from pyinfra.config.validators import webserver_validators
from pyinfra.utils.opentelemetry import instrument_app, setup_trace
def create_webserver_thread_from_settings(app: FastAPI, settings: Dynaconf) -> threading.Thread:
validate_settings(settings, validators=webserver_validators)
if settings.tracing.opentelemetry.enabled:
return create_webserver_thread_with_tracing(app, settings)
return create_webserver_thread(app=app, port=settings.webserver.port, host=settings.webserver.host)
def create_webserver_thread(app: FastAPI, port: int, host: str) -> threading.Thread:
"""Creates a thread that runs a FastAPI webserver. Start with thread.start(), and join with thread.join().
Note that the thread is a daemon thread, so it will be terminated when the main thread is terminated.
"""
thread = threading.Thread(target=lambda: uvicorn.run(app, port=port, host=host, log_level=logging.WARNING))
thread.daemon = True
return thread
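The daemon-thread pattern used here (wrap the blocking server call in a `Thread` and mark it daemon so it dies with the main thread) can be shown with a plain blocking worker standing in for `uvicorn.run`; the worker and queue are stand-ins for this sketch:

```python
import queue
import threading


def create_worker_thread(q: "queue.Queue[int]", results: list) -> threading.Thread:
    def run():
        while True:
            item = q.get()  # blocks, like uvicorn.run blocks the thread
            if item is None:
                break
            results.append(item * 2)

    thread = threading.Thread(target=run)
    thread.daemon = True  # terminated when the main thread exits, like the webserver
    return thread


q, results = queue.Queue(), []
thread = create_worker_thread(q, results)
thread.start()
q.put(21)
q.put(None)  # sentinel so the sketch shuts down cleanly; the real thread just dies with main
thread.join(timeout=5)
assert results == [42]
```

The daemon flag is the important part: the consumer loop in the main thread owns the process lifetime, and the embedded webserver is torn down implicitly when it exits.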
def create_webserver_thread_with_tracing(app: FastAPI, settings: Dynaconf) -> threading.Thread:
def inner():
setup_trace(settings)
instrument_app(app)
uvicorn.run(app, port=settings.webserver.port, host=settings.webserver.host, log_level=logging.WARNING)
thread = threading.Thread(target=inner)
thread.daemon = True
return thread
HealthFunction = Callable[[], bool]
def add_health_check_endpoint(app: FastAPI, health_function: HealthFunction) -> FastAPI:
"""Add a health check endpoint to the app. The health function should return True if the service is healthy,
and False otherwise. The health function is called when the endpoint is hit.
"""
    @app.get("/health")
    @app.get("/ready")
    def check_health():
        # FastAPI does not interpret Flask-style (body, status) tuples, so return an
        # explicit response to make the 503 status code actually reach the client.
        from fastapi.responses import JSONResponse

        if health_function():
            return JSONResponse({"status": "OK"}, status_code=200)
        return JSONResponse({"status": "Service Unavailable"}, status_code=503)
return app

@@ -1,10 +1,9 @@
[tool.poetry]
name = "pyinfra"
version = "1.10.0"
version = "2.0.0"
description = ""
authors = ["Team Research <research@knecon.com>"]
license = "All rights reserved"
readme = "README.md"
[tool.poetry.dependencies]
python = ">=3.10,<3.11"
@@ -20,7 +19,19 @@ azure-storage-blob = "^12.13"
funcy = "^2"
pycryptodome = "^3.19"
# research shared packages
kn-utils = { version = "^0.2.4.dev112", source = "gitlab-research" }
kn-utils = { version = "^0.2.7", source = "gitlab-research" }
fastapi = "^0.109.0"
uvicorn = "^0.26.0"
# [tool.poetry.group.telemetry.dependencies]
opentelemetry-instrumentation-pika = "^0.43b0"
opentelemetry-exporter-otlp = "^1.22.0"
opentelemetry-instrumentation = "^0.43b0"
opentelemetry-api = "^1.22.0"
opentelemetry-sdk = "^1.22.0"
opentelemetry-exporter-otlp-proto-http = "^1.22.0"
opentelemetry-instrumentation-flask = "^0.43b0"
opentelemetry-instrumentation-requests = "^0.43b0"
opentelemetry-instrumentation-fastapi = "^0.43b0"
[tool.poetry.group.dev.dependencies]
pytest = "^7"
@@ -29,15 +40,40 @@ black = "^23.10"
pylint = "^3"
coverage = "^7.3"
requests = "^2.31"
pre-commit = "^3.6.0"
[tool.pytest.ini_options]
minversion = "6.0"
addopts = "-ra -q"
testpaths = ["tests", "integration"]
norecursedirs = "tests/tests_with_docker_compose"
log_cli = 1
log_cli_level = "DEBUG"
[tool.mypy]
exclude = ['.venv']
[tool.black]
line-length = 120
target-version = ["py310"]
[tool.isort]
profile = "black"
[tool.pylint.format]
max-line-length = 120
disable = [
"C0114",
"C0325",
"R0801",
"R0902",
"R0903",
"R0904",
"R0913",
"R0914",
"W0511"
]
docstring-min-length = 3
[[tool.poetry.source]]
name = "PyPI"
priority = "primary"

@@ -1,22 +1,17 @@
import gzip
import json
import logging
from operator import itemgetter
import pika
from kn_utils.logging import logger
from pyinfra.config import get_config
from pyinfra.queue.development_queue_manager import DevelopmentQueueManager
from pyinfra.storage.storages.s3 import get_s3_storage_from_config
from pyinfra.config.loader import load_settings, local_pyinfra_root_path
from pyinfra.queue.manager import QueueManager
from pyinfra.storage.storages.s3 import get_s3_storage_from_settings
CONFIG = get_config()
logging.basicConfig()
logger = logging.getLogger()
logger.setLevel(logging.INFO)
settings = load_settings(local_pyinfra_root_path / "config/")
def upload_json_and_make_message_body():
bucket = CONFIG.storage_bucket
dossier_id, file_id, suffix = "dossier", "file", "json.gz"
content = {
"numberOfPages": 7,
@@ -26,10 +21,10 @@ def upload_json_and_make_message_body():
object_name = f"{dossier_id}/{file_id}.{suffix}"
data = gzip.compress(json.dumps(content).encode("utf-8"))
storage = get_s3_storage_from_config(CONFIG)
if not storage.has_bucket(bucket):
storage.make_bucket(bucket)
storage.put_object(bucket, object_name, data)
storage = get_s3_storage_from_settings(settings)
if not storage.has_bucket():
storage.make_bucket()
storage.put_object(object_name, data)
message_body = {
"dossierId": dossier_id,
@ -41,31 +36,31 @@ def upload_json_and_make_message_body():
def main():
development_queue_manager = DevelopmentQueueManager(CONFIG)
development_queue_manager.clear_queues()
queue_manager = QueueManager(settings)
queue_manager.purge_queues()
message = upload_json_and_make_message_body()
development_queue_manager.publish_request(message, pika.BasicProperties(headers={"X-TENANT-ID": "redaction"}))
logger.info(f"Put {message} on {CONFIG.request_queue}")
queue_manager.publish_message_to_input_queue(message)
logger.info(f"Put {message} on {settings.rabbitmq.input_queue}.")
storage = get_s3_storage_from_config(CONFIG)
for method_frame, properties, body in development_queue_manager._channel.consume(
queue=CONFIG.response_queue, inactivity_timeout=15
storage = get_s3_storage_from_settings(settings)
for method_frame, properties, body in queue_manager.channel.consume(
queue=settings.rabbitmq.output_queue, inactivity_timeout=15
):
if not body:
break
response = json.loads(body)
logger.info(f"Received {response}")
logger.info(f"Message headers: {properties.headers}")
development_queue_manager._channel.basic_ack(method_frame.delivery_tag)
queue_manager.channel.basic_ack(method_frame.delivery_tag)
dossier_id, file_id = itemgetter("dossierId", "fileId")(response)
suffix = message["responseFileExtension"]
print(f"{dossier_id}/{file_id}.{suffix}")
result = storage.get_object(CONFIG.storage_bucket, f"{dossier_id}/{file_id}.{suffix}")
result = storage.get_object(f"{dossier_id}/{file_id}.{suffix}")
result = json.loads(gzip.decompress(result))
logger.info(f"Contents of result on storage: {result}")
development_queue_manager.close_channel()
queue_manager.stop_consuming()
if __name__ == "__main__":

@@ -1,25 +1,18 @@
import logging
import time
from pyinfra.config import get_config
from pyinfra.payload_processing.processor import make_payload_processor
from pyinfra.queue.queue_manager import QueueManager
logging.basicConfig()
logger = logging.getLogger()
logger.setLevel(logging.INFO)
from pyinfra.config.loader import load_settings, parse_settings_path
from pyinfra.examples import start_standard_queue_consumer
from pyinfra.queue.callback import make_download_process_upload_callback
def json_processor_mock(data: dict):
def processor_mock(_data: dict, _message: dict) -> dict:
time.sleep(5)
return [{"result1": "result1"}, {"result2": "result2"}]
def main():
logger.info("Start consuming...")
queue_manager = QueueManager(get_config())
queue_manager.start_consuming(make_payload_processor(json_processor_mock))
return {"result1": "result1"}
if __name__ == "__main__":
main()
arguments = parse_settings_path()
settings = load_settings(arguments.settings_path)
callback = make_download_process_upload_callback(processor_mock, settings)
start_standard_queue_consumer(callback, settings)

@@ -1,142 +1,46 @@
import gzip
import json
import pytest
from pyinfra.payload_processing.payload import LegacyQueueMessagePayload, QueueMessagePayload
from pyinfra.config.loader import load_settings, local_pyinfra_root_path
from pyinfra.queue.manager import QueueManager
from pyinfra.storage.connection import get_storage
@pytest.fixture(scope="session")
def settings():
return load_settings(local_pyinfra_root_path / "config/")
@pytest.fixture(scope="class")
def storage(storage_backend, settings):
settings.storage.backend = storage_backend
storage = get_storage(settings)
storage.make_bucket()
yield storage
storage.clear_bucket()
@pytest.fixture(scope="session")
def queue_manager(settings):
settings.rabbitmq_heartbeat = 10
settings.connection_sleep = 5
queue_manager = QueueManager(settings)
yield queue_manager
@pytest.fixture
def legacy_payload(x_tenant_id, optional_processing_kwargs):
x_tenant_entry = {"X-TENANT-ID": x_tenant_id} if x_tenant_id else {}
optional_processing_kwargs = optional_processing_kwargs or {}
return {
"dossierId": "test",
"fileId": "test",
"targetFileExtension": "target.json.gz",
"responseFileExtension": "response.json.gz",
**x_tenant_entry,
**optional_processing_kwargs,
}
@pytest.fixture
def target_file_path():
return "test/test.target.json.gz"
@pytest.fixture
def response_file_path():
return "test/test.response.json.gz"
@pytest.fixture
def payload(x_tenant_id, optional_processing_kwargs, target_file_path, response_file_path):
x_tenant_entry = {"X-TENANT-ID": x_tenant_id} if x_tenant_id else {}
optional_processing_kwargs = optional_processing_kwargs or {}
return {
"targetFilePath": target_file_path,
"responseFilePath": response_file_path,
**x_tenant_entry,
**optional_processing_kwargs,
}
@pytest.fixture
def legacy_queue_response_payload(x_tenant_id, optional_processing_kwargs):
x_tenant_entry = {"X-TENANT-ID": x_tenant_id} if x_tenant_id else {}
optional_processing_kwargs = optional_processing_kwargs or {}
return {
"dossierId": "test",
"fileId": "test",
**x_tenant_entry,
**optional_processing_kwargs,
}
@pytest.fixture
def queue_response_payload(x_tenant_id, optional_processing_kwargs, target_file_path, response_file_path):
x_tenant_entry = {"X-TENANT-ID": x_tenant_id} if x_tenant_id else {}
optional_processing_kwargs = optional_processing_kwargs or {}
return {
"targetFilePath": target_file_path,
"responseFilePath": response_file_path,
**x_tenant_entry,
**optional_processing_kwargs,
}
@pytest.fixture
def legacy_storage_payload(x_tenant_id, optional_processing_kwargs, processing_result_json):
x_tenant_entry = {"X-TENANT-ID": x_tenant_id} if x_tenant_id else {}
optional_processing_kwargs = optional_processing_kwargs or {}
return {
"dossierId": "test",
"fileId": "test",
"targetFileExtension": "target.json.gz",
"responseFileExtension": "response.json.gz",
**x_tenant_entry,
**optional_processing_kwargs,
"data": processing_result_json,
}
@pytest.fixture
def storage_payload(x_tenant_id, optional_processing_kwargs, processing_result_json, target_file_path, response_file_path):
x_tenant_entry = {"X-TENANT-ID": x_tenant_id} if x_tenant_id else {}
optional_processing_kwargs = optional_processing_kwargs or {}
return {
"targetFilePath": target_file_path,
"responseFilePath": response_file_path,
**x_tenant_entry,
**optional_processing_kwargs,
"data": processing_result_json,
}
@pytest.fixture
def legacy_parsed_payload(
x_tenant_id, optional_processing_kwargs, target_file_path, response_file_path
) -> LegacyQueueMessagePayload:
return LegacyQueueMessagePayload(
dossier_id="test",
file_id="test",
x_tenant_id=x_tenant_id,
target_file_extension="target.json.gz",
response_file_extension="response.json.gz",
target_file_type="json",
target_compression_type="gz",
response_file_type="json",
response_compression_type="gz",
target_file_path=target_file_path,
response_file_path=response_file_path,
processing_kwargs=optional_processing_kwargs or {},
)
@pytest.fixture
def input_message():
return json.dumps(
{
"targetFilePath": "test/target.json.gz",
"responseFilePath": "test/response.json.gz",
}
)
@pytest.fixture
def parsed_payload(
x_tenant_id, optional_processing_kwargs, target_file_path, response_file_path
) -> QueueMessagePayload:
return QueueMessagePayload(
x_tenant_id=x_tenant_id,
target_file_type="json",
target_compression_type="gz",
response_file_type="json",
response_compression_type="gz",
target_file_path=target_file_path,
response_file_path=response_file_path,
processing_kwargs=optional_processing_kwargs or {},
)
@pytest.fixture
def target_json_file() -> bytes:
data = {"target": "test"}
enc_data = json.dumps(data).encode("utf-8")
compr_data = gzip.compress(enc_data)
return compr_data
@pytest.fixture
def processing_result_json() -> dict:
return {"response": "test"}
@pytest.fixture
def stop_message():
return "STOP"


@@ -8,7 +8,7 @@ services:
- MINIO_ROOT_PASSWORD=password
- MINIO_ROOT_USER=root
volumes:
- ./data/minio_store:/data
- /tmp/minio_store:/data
command: server /data
network_mode: "bridge"
rabbitmq:


@@ -1,123 +0,0 @@
import json
import logging
import time
from multiprocessing import Process
import pika
import pika.exceptions
import pytest
from pyinfra.queue.development_queue_manager import DevelopmentQueueManager
from pyinfra.queue.queue_manager import QueueManager
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)
@pytest.fixture(scope="session")
def development_queue_manager(test_queue_config):
test_queue_config.rabbitmq_heartbeat = 7200
development_queue_manager = DevelopmentQueueManager(test_queue_config)
yield development_queue_manager
logger.info("Tearing down development queue manager...")
try:
development_queue_manager.close_channel()
except pika.exceptions.ConnectionClosedByBroker:
pass
@pytest.fixture(scope="session")
def payload_processing_time(test_queue_config, offset=5):
# FIXME: this implicitly tests the heartbeat when running the end-to-end test. There should be another way to test
# this explicitly.
return test_queue_config.rabbitmq_heartbeat + offset
@pytest.fixture(scope="session")
def payload_processor(response_payload, payload_processing_time, payload_processor_type):
def process(payload):
time.sleep(payload_processing_time)
return response_payload
def process_with_failure(payload):
raise MemoryError
if payload_processor_type == "mock":
return process
elif payload_processor_type == "failing":
return process_with_failure
@pytest.fixture(scope="session", autouse=True)
def start_queue_consumer(test_queue_config, payload_processor, sleep_seconds=5):
def consume_queue():
queue_manager.start_consuming(payload_processor)
queue_manager = QueueManager(test_queue_config)
p = Process(target=consume_queue)
p.start()
logger.info(f"Setting up consumer, waiting for {sleep_seconds}...")
time.sleep(sleep_seconds)
yield
logger.info("Tearing down consumer...")
p.terminate()
@pytest.fixture
def message_properties(message_headers):
if not message_headers:
return pika.BasicProperties(headers=None)
elif message_headers == "X-TENANT-ID":
return pika.BasicProperties(headers={"X-TENANT-ID": "redaction"})
else:
raise Exception(f"Invalid {message_headers=}.")
@pytest.mark.parametrize("x_tenant_id", [None])
class TestQueueManager:
# FIXME: All tests here are wonky. This is due to the implementation of running the process-blocking queue_manager
# in a subprocess. It is then very hard to interact directly with the subprocess. If you have a better idea, please
# refactor; the tests here are insufficient to ensure the functionality of the queue manager!
@pytest.mark.parametrize("payload_processor_type", ["mock"], scope="session")
def test_message_processing_does_not_block_heartbeat(
self, development_queue_manager, payload, response_payload, payload_processing_time
):
development_queue_manager.clear_queues()
development_queue_manager.publish_request(payload)
time.sleep(payload_processing_time + 10)
_, _, body = development_queue_manager.get_response()
result = json.loads(body)
assert result == response_payload
@pytest.mark.parametrize("message_headers", [None, "X-TENANT-ID"])
@pytest.mark.parametrize("payload_processor_type", ["mock"], scope="session")
def test_queue_manager_forwards_message_headers(
self,
development_queue_manager,
payload,
response_payload,
payload_processing_time,
message_properties,
):
development_queue_manager.clear_queues()
development_queue_manager.publish_request(payload, message_properties)
time.sleep(payload_processing_time + 10)
_, properties, _ = development_queue_manager.get_response()
assert properties.headers == message_properties.headers
# FIXME: It is not possible to test the behavior of the queue manager directly, since it is running in a separate
# process. You require logging to see if the exception is handled correctly. Hence, this test is only useful for
# development, but insufficient to guarantee the correct behavior.
@pytest.mark.parametrize("payload_processor_type", ["failing"], scope="session")
def test_failed_message_processing_is_handled(
self,
development_queue_manager,
payload,
response_payload,
payload_processing_time,
):
development_queue_manager.clear_queues()
development_queue_manager.publish_request(payload)
time.sleep(payload_processing_time + 10)
_, _, body = development_queue_manager.get_response()
assert not body


@@ -1,57 +0,0 @@
import logging
import pytest
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)
@pytest.mark.parametrize("storage_backend", ["azure", "s3"], scope="session")
@pytest.mark.parametrize("bucket_name", ["testbucket"], scope="session")
@pytest.mark.parametrize("monitoring_enabled", [False], scope="session")
class TestStorage:
def test_clearing_bucket_yields_empty_bucket(self, storage, bucket_name):
storage.clear_bucket(bucket_name)
data_received = storage.get_all_objects(bucket_name)
assert not {*data_received}
def test_getting_object_put_in_bucket_is_object(self, storage, bucket_name):
storage.clear_bucket(bucket_name)
storage.put_object(bucket_name, "file", b"content")
data_received = storage.get_object(bucket_name, "file")
assert b"content" == data_received
def test_object_put_in_bucket_exists_on_storage(self, storage, bucket_name):
storage.clear_bucket(bucket_name)
storage.put_object(bucket_name, "file", b"content")
assert storage.exists(bucket_name, "file")
def test_getting_nested_object_put_in_bucket_is_nested_object(self, storage, bucket_name):
storage.clear_bucket(bucket_name)
storage.put_object(bucket_name, "folder/file", b"content")
data_received = storage.get_object(bucket_name, "folder/file")
assert b"content" == data_received
def test_getting_objects_put_in_bucket_are_objects(self, storage, bucket_name):
storage.clear_bucket(bucket_name)
storage.put_object(bucket_name, "file1", b"content 1")
storage.put_object(bucket_name, "folder/file2", b"content 2")
data_received = storage.get_all_objects(bucket_name)
assert {b"content 1", b"content 2"} == {*data_received}
def test_make_bucket_produces_bucket(self, storage, bucket_name):
storage.clear_bucket(bucket_name)
storage.make_bucket(bucket_name)
assert storage.has_bucket(bucket_name)
def test_listing_bucket_files_yields_all_files_in_bucket(self, storage, bucket_name):
storage.clear_bucket(bucket_name)
storage.put_object(bucket_name, "file1", b"content 1")
storage.put_object(bucket_name, "file2", b"content 2")
full_names_received = storage.get_all_object_names(bucket_name)
assert {(bucket_name, "file1"), (bucket_name, "file2")} == {*full_names_received}
def test_data_loading_failure_raised_if_object_not_present(self, storage, bucket_name):
storage.clear_bucket(bucket_name)
with pytest.raises(Exception):
storage.get_object(bucket_name, "folder/file")


@@ -0,0 +1,55 @@
import os
from pathlib import Path
import pytest
from dynaconf import Validator
from pyinfra.config.loader import load_settings, local_pyinfra_root_path, normalize_to_settings_files
from pyinfra.config.validators import webserver_validators
@pytest.fixture
def test_validators():
return [
Validator("test.value.int", must_exist=True, is_type_of=int),
Validator("test.value.str", must_exist=True, is_type_of=str),
]
class TestConfig:
def test_config_validation(self):
os.environ["WEBSERVER__HOST"] = "localhost"
os.environ["WEBSERVER__PORT"] = "8080"
validators = webserver_validators
test_settings = load_settings(root_path=local_pyinfra_root_path, validators=validators)
assert test_settings.webserver.host == "localhost"
def test_env_into_correct_type_conversion(self, test_validators):
os.environ["TEST__VALUE__INT"] = "1"
os.environ["TEST__VALUE__STR"] = "test"
test_settings = load_settings(root_path=local_pyinfra_root_path, validators=test_validators)
assert test_settings.test.value.int == 1
assert test_settings.test.value.str == "test"
@pytest.mark.parametrize(
"settings_path,expected_file_paths",
[
(None, []),
("config", [f"{local_pyinfra_root_path}/config/settings.toml"]),
("config/settings.toml", [f"{local_pyinfra_root_path}/config/settings.toml"]),
(f"{local_pyinfra_root_path}/config", [f"{local_pyinfra_root_path}/config/settings.toml"]),
],
)
def test_normalize_settings_files(self, settings_path, expected_file_paths):
files = normalize_to_settings_files(settings_path, local_pyinfra_root_path)
print(files)
assert len(files) == len(expected_file_paths)
for path, expected in zip(files, expected_file_paths):
assert path == Path(expected).absolute()
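The config tests above lean on the double-underscore convention for environment variables (`WEBSERVER__HOST` becomes `settings.webserver.host`); the nesting itself is handled by Dynaconf, and type conversion by the validators. A minimal stdlib sketch of just the nesting step (the `nest_env` helper is hypothetical; values stay strings here, since casting is the validators' job):

```python
def nest_env(environ: dict) -> dict:
    """Turn FOO__BAR__BAZ=v entries into {"foo": {"bar": {"baz": "v"}}}."""
    settings: dict = {}
    for key, value in environ.items():
        parts = [part.lower() for part in key.split("__")]
        node = settings
        for part in parts[:-1]:
            node = node.setdefault(part, {})
        node[parts[-1]] = value
    return settings


nested = nest_env({"WEBSERVER__HOST": "localhost", "WEBSERVER__PORT": "8080"})
```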


@@ -4,7 +4,7 @@ from kn_utils.logging import logger
def test_necessary_log_levels_are_supported_by_kn_utils():
logger.setLevel("TRACE")
logger.trace("trace")
logger.debug("debug")
logger.info("info")
@@ -13,6 +13,7 @@ def test_necessary_log_levels_are_supported_by_kn_utils():
logger.exception("exception", exc_info="this is an exception")
logger.error("error", exc_info="this is an error")
def test_setlevel_warn():
logger.setLevel("WARN")
logger.warning("warn")


@@ -0,0 +1,51 @@
from time import sleep
import pytest
from pyinfra.utils.opentelemetry import get_exporter, instrument_pika, setup_trace
@pytest.fixture(scope="session")
def exporter(settings):
settings.tracing.opentelemetry.exporter = "json"
return get_exporter(settings)
class TestOpenTelemetry:
def test_queue_messages_are_traced(self, queue_manager, input_message, stop_message, settings, exporter):
setup_trace(settings, exporter=exporter)
instrument_pika()
queue_manager.purge_queues()
queue_manager.publish_message_to_input_queue(input_message)
queue_manager.publish_message_to_input_queue(stop_message)
def callback(_):
sleep(2)
return {"flat": "earth"}
queue_manager.start_consuming(callback)
for exported_trace in exporter.traces:
assert (
exported_trace["resource"]["attributes"]["service.name"] == settings.tracing.opentelemetry.service_name
)
# def test_webserver_requests_are_traced(self, settings):
# settings.tracing.opentelemetry.exporter = "console"
# settings.tracing.opentelemetry.enabled = True
#
# app = FastAPI()
#
# @app.get("/test")
# def test():
# return {"test": "test"}
#
# thread = create_webserver_thread_from_settings(app, settings)
# thread.start()
# sleep(1)
#
# requests.get(f"http://{settings.webserver.host}:{settings.webserver.port}/test")
#
# thread.join(timeout=1)


@@ -0,0 +1,55 @@
import re
from time import sleep
import pytest
import requests
from fastapi import FastAPI
from pyinfra.webserver.prometheus import (
add_prometheus_endpoint,
make_prometheus_processing_time_decorator_from_settings,
)
from pyinfra.webserver.utils import create_webserver_thread_from_settings
@pytest.fixture(scope="class")
def app_with_prometheus_endpoint(settings):
app = FastAPI()
app = add_prometheus_endpoint(app)
thread = create_webserver_thread_from_settings(app, settings)
thread.daemon = True
thread.start()
sleep(1)
yield
thread.join(timeout=1)
@pytest.fixture
def monitored_function(settings):
@make_prometheus_processing_time_decorator_from_settings(settings)
def process(*args, **kwargs):
sleep(0.5)
return process
class TestPrometheusMonitor:
def test_prometheus_endpoint_is_available(self, app_with_prometheus_endpoint, settings):
resp = requests.get(f"http://{settings.webserver.host}:{settings.webserver.port}/prometheus")
assert resp.status_code == 200
def test_processing_with_a_monitored_fn_increases_parameter_counter(
self, app_with_prometheus_endpoint, monitored_function, settings
):
pattern = re.compile(rf".*{settings.metrics.prometheus.prefix}_processing_time_count (\d\.\d).*")
resp = requests.get(f"http://{settings.webserver.host}:{settings.webserver.port}/prometheus")
assert pattern.search(resp.text).group(1) == "0.0"
monitored_function()
resp = requests.get(f"http://{settings.webserver.host}:{settings.webserver.port}/prometheus")
assert pattern.search(resp.text).group(1) == "1.0"
monitored_function()
resp = requests.get(f"http://{settings.webserver.host}:{settings.webserver.port}/prometheus")
assert pattern.search(resp.text).group(1) == "2.0"
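The assertions above scrape the counter value out of the Prometheus text exposition format with a regex. A self-contained sketch of a counter that renders a line the same pattern would match (the `pyinfra` metric name is an assumed example, not the configured prefix):

```python
import re


class Counter:
    """Tiny stand-in for a Prometheus counter with text exposition."""

    def __init__(self, name: str) -> None:
        self.name = name
        self.value = 0.0

    def inc(self) -> None:
        self.value += 1.0

    def expose(self) -> str:
        # One sample line of the text exposition format: "<name> <value>"
        return f"{self.name} {self.value}"


counter = Counter("pyinfra_processing_time_count")
counter.inc()
pattern = re.compile(r".*pyinfra_processing_time_count (\d\.\d).*")
match = pattern.search(counter.expose())
```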


@@ -0,0 +1,77 @@
import json
from sys import stdout
from time import sleep
import pika
import pytest
from kn_utils.logging import logger
logger.remove()
logger.add(sink=stdout, level="DEBUG")
def make_callback(process_time):
def callback(x):
sleep(process_time)
return {"status": "success"}
return callback
class TestQueueManager:
def test_processing_of_several_messages(self, queue_manager, input_message, stop_message):
queue_manager.purge_queues()
for _ in range(2):
queue_manager.publish_message_to_input_queue(input_message)
queue_manager.publish_message_to_input_queue(stop_message)
callback = make_callback(1)
queue_manager.start_consuming(callback)
for _ in range(2):
response = queue_manager.get_message_from_output_queue()
assert response is not None
assert json.loads(response[2].decode()) == {"status": "success"}
def test_all_headers_beginning_with_x_are_forwarded(self, queue_manager, input_message, stop_message):
queue_manager.purge_queues()
properties = pika.BasicProperties(
headers={
"X-TENANT-ID": "redaction",
"X-OTHER-HEADER": "other-header-value",
"x-tenant_id": "tenant-id-value",
"x_should_not_be_forwarded": "should-not-be-forwarded-value",
}
)
queue_manager.publish_message_to_input_queue(input_message, properties=properties)
queue_manager.publish_message_to_input_queue(stop_message)
callback = make_callback(0.2)
queue_manager.start_consuming(callback)
response = queue_manager.get_message_from_output_queue()
assert json.loads(response[2].decode()) == {"status": "success"}
assert response[1].headers["X-TENANT-ID"] == "redaction"
assert response[1].headers["X-OTHER-HEADER"] == "other-header-value"
assert response[1].headers["x-tenant_id"] == "tenant-id-value"
assert "x_should_not_be_forwarded" not in response[1].headers
def test_message_processing_does_not_block_heartbeat(self, queue_manager, input_message, stop_message):
queue_manager.purge_queues()
queue_manager.publish_message_to_input_queue(input_message)
queue_manager.publish_message_to_input_queue(stop_message)
callback = make_callback(15)
queue_manager.start_consuming(callback)
response = queue_manager.get_message_from_output_queue()
assert json.loads(response[2].decode()) == {"status": "success"}
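The header test above pins down the forwarding rule: header names starting with `X-`/`x-` (hyphen, not underscore) are forwarded, everything else is dropped. A sketch of a filter with exactly that behavior (`forwardable_headers` is a hypothetical name, not the QueueManager API):

```python
def forwardable_headers(headers):
    """Keep only headers whose names begin with 'X-' or 'x-' (case-insensitive)."""
    return {
        name: value
        for name, value in (headers or {}).items()
        if name[:2].lower() == "x-"
    }


headers = {
    "X-TENANT-ID": "redaction",
    "x-tenant_id": "tenant-id-value",
    "x_should_not_be_forwarded": "dropped",  # underscore after x: not forwarded
}
forwarded = forwardable_headers(headers)
```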


@@ -0,0 +1,154 @@
import gzip
import json
from time import sleep
import pytest
from fastapi import FastAPI
from pyinfra.storage.connection import get_storage_for_tenant
from pyinfra.storage.utils import (
download_data_as_specified_in_message,
upload_data_as_specified_in_message,
)
from pyinfra.utils.cipher import encrypt
from pyinfra.webserver.utils import create_webserver_thread
@pytest.mark.parametrize("storage_backend", ["azure", "s3"], scope="class")
class TestStorage:
def test_clearing_bucket_yields_empty_bucket(self, storage):
storage.clear_bucket()
data_received = storage.get_all_objects()
assert not {*data_received}
def test_getting_object_put_in_bucket_is_object(self, storage):
storage.clear_bucket()
storage.put_object("file", b"content")
data_received = storage.get_object("file")
assert b"content" == data_received
def test_object_put_in_bucket_exists_on_storage(self, storage):
storage.clear_bucket()
storage.put_object("file", b"content")
assert storage.exists("file")
def test_getting_nested_object_put_in_bucket_is_nested_object(self, storage):
storage.clear_bucket()
storage.put_object("folder/file", b"content")
data_received = storage.get_object("folder/file")
assert b"content" == data_received
def test_getting_objects_put_in_bucket_are_objects(self, storage):
storage.clear_bucket()
storage.put_object("file1", b"content 1")
storage.put_object("folder/file2", b"content 2")
data_received = storage.get_all_objects()
assert {b"content 1", b"content 2"} == {*data_received}
def test_make_bucket_produces_bucket(self, storage):
storage.clear_bucket()
storage.make_bucket()
assert storage.has_bucket()
def test_listing_bucket_files_yields_all_files_in_bucket(self, storage):
storage.clear_bucket()
storage.put_object("file1", b"content 1")
storage.put_object("file2", b"content 2")
full_names_received = storage.get_all_object_names()
assert {(storage.bucket, "file1"), (storage.bucket, "file2")} == {*full_names_received}
def test_data_loading_failure_raised_if_object_not_present(self, storage):
storage.clear_bucket()
with pytest.raises(Exception):
storage.get_object("folder/file")
@pytest.fixture(scope="class")
def tenant_server_mock(settings, tenant_server_host, tenant_server_port):
app = FastAPI()
@app.get("/azure_tenant")
def get_azure_storage_info():
return {
"azureStorageConnection": {
"connectionString": encrypt(
settings.storage.tenant_server.public_key, settings.storage.azure.connection_string
),
"containerName": settings.storage.azure.container,
}
}
@app.get("/s3_tenant")
def get_s3_storage_info():
return {
"s3StorageConnection": {
"endpoint": settings.storage.s3.endpoint,
"key": settings.storage.s3.key,
"secret": encrypt(settings.storage.tenant_server.public_key, settings.storage.s3.secret),
"region": settings.storage.s3.region,
"bucketName": settings.storage.s3.bucket,
}
}
thread = create_webserver_thread(app, tenant_server_port, tenant_server_host)
thread.daemon = True
thread.start()
sleep(1)
yield
thread.join(timeout=1)
@pytest.mark.parametrize("tenant_id", ["azure_tenant", "s3_tenant"], scope="class")
@pytest.mark.parametrize("tenant_server_host", ["localhost"], scope="class")
@pytest.mark.parametrize("tenant_server_port", [8000], scope="class")
class TestMultiTenantStorage:
def test_storage_connection_from_tenant_id(
self, tenant_id, tenant_server_mock, settings, tenant_server_host, tenant_server_port
):
settings["storage"]["tenant_server"]["endpoint"] = f"http://{tenant_server_host}:{tenant_server_port}"
storage = get_storage_for_tenant(
tenant_id,
settings["storage"]["tenant_server"]["endpoint"],
settings["storage"]["tenant_server"]["public_key"],
)
storage.put_object("file", b"content")
data_received = storage.get_object("file")
assert b"content" == data_received
@pytest.fixture
def payload(payload_type):
if payload_type == "target_response_file_path":
return {
"targetFilePath": "test/file.target.json.gz",
"responseFilePath": "test/file.response.json.gz",
}
elif payload_type == "dossier_id_file_id":
return {
"dossierId": "test",
"fileId": "file",
"targetFileExtension": "target.json.gz",
"responseFileExtension": "response.json.gz",
}
@pytest.mark.parametrize("payload_type", ["target_response_file_path", "dossier_id_file_id"], scope="class")
@pytest.mark.parametrize("storage_backend", ["azure", "s3"], scope="class")
class TestDownloadAndUploadFromMessage:
def test_download_and_upload_from_message(self, storage, payload):
storage.clear_bucket()
input_data = {"data": "success"}
storage.put_object("test/file.target.json.gz", gzip.compress(json.dumps(input_data).encode()))
data = download_data_as_specified_in_message(storage, payload)
assert data == input_data
upload_data_as_specified_in_message(storage, payload, input_data)
data = json.loads(gzip.decompress(storage.get_object("test/file.response.json.gz")).decode())
assert data == {**payload, "data": input_data}
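The `.json.gz` round trip these storage tests exercise is plain JSON encoding plus gzip compression. A self-contained sketch of that convention (helper names are illustrative, not the pyinfra API):

```python
import gzip
import json


def pack_json_gz(data: dict) -> bytes:
    """JSON-encode, then gzip: the shape of a stored *.json.gz object."""
    return gzip.compress(json.dumps(data).encode("utf-8"))


def unpack_json_gz(blob: bytes) -> dict:
    """Inverse: gunzip, then JSON-decode."""
    return json.loads(gzip.decompress(blob).decode("utf-8"))


payload = {"data": "success"}
blob = pack_json_gz(payload)
restored = unpack_json_gz(blob)
```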


@@ -1,32 +0,0 @@
import pytest
from pyinfra.utils.file_extension_parsing import make_file_extension_parser
@pytest.fixture
def file_extension_parser(file_types, compression_types):
return make_file_extension_parser(file_types, compression_types)
@pytest.mark.parametrize(
"file_path,file_types,compression_types,expected_file_extension,expected_compression_extension",
[
("test.txt", ["txt"], ["gz"], "txt", None),
("test.txt.gz", ["txt"], ["gz"], "txt", "gz"),
("test.txt.gz", [], [], None, None),
("test.txt.gz", ["txt"], [], "txt", None),
("test.txt.gz", [], ["gz"], None, "gz"),
("test", ["txt"], ["gz"], None, None),
],
)
def test_file_extension_parsing(
file_extension_parser,
file_path,
file_types,
compression_types,
expected_file_extension,
expected_compression_extension,
):
file_extension, compression_extension = file_extension_parser(file_path)
assert file_extension == expected_file_extension
assert compression_extension == expected_compression_extension
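One implementation consistent with every parametrized case above is to pick, for each allow-list, the first dotted suffix it contains. This is a sketch, not necessarily how `make_file_extension_parser` is actually written:

```python
def make_file_extension_parser(file_types, compression_types):
    def parse(file_path):
        # Everything after the base name, e.g. "a.txt.gz" -> ["txt", "gz"]
        suffixes = file_path.split(".")[1:]
        file_extension = next((s for s in suffixes if s in file_types), None)
        compression_extension = next((s for s in suffixes if s in compression_types), None)
        return file_extension, compression_extension

    return parse


parse = make_file_extension_parser(["txt"], ["gz"])
```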


@@ -1,44 +0,0 @@
import re
import time
import pytest
import requests
from pyinfra.payload_processing.monitor import PrometheusMonitor
@pytest.fixture(scope="class")
def monitored_mock_function(metric_prefix, host, port):
def process(data=None):
time.sleep(2)
return ["result1", "result2", "result3"]
monitor = PrometheusMonitor(metric_prefix, host, port)
return monitor(process)
@pytest.fixture
def metric_endpoint(host, port):
return f"http://{host}:{port}/prometheus"
@pytest.mark.parametrize("metric_prefix, host, port", [("test", "0.0.0.0", 8000)], scope="class")
class TestPrometheusMonitor:
def test_prometheus_endpoint_is_available(self, metric_endpoint, monitored_mock_function):
resp = requests.get(metric_endpoint)
assert resp.status_code == 200
def test_processing_with_a_monitored_fn_increases_parameter_counter(
self,
metric_endpoint,
metric_prefix,
monitored_mock_function,
):
monitored_mock_function(data=None)
resp = requests.get(metric_endpoint)
pattern = re.compile(rf".*{metric_prefix}_processing_time_count (\d\.\d).*")
assert pattern.search(resp.text).group(1) == "1.0"
monitored_mock_function(data=None)
resp = requests.get(metric_endpoint)
assert pattern.search(resp.text).group(1) == "2.0"


@@ -1,48 +0,0 @@
import pytest
from pyinfra.config import get_config
from pyinfra.payload_processing.payload import (
get_queue_message_payload_parser,
format_to_queue_message_response_body,
format_service_processing_result_for_storage,
)
@pytest.fixture
def payload_parser():
config = get_config()
return get_queue_message_payload_parser(config)
@pytest.mark.parametrize("x_tenant_id", [None, "klaus"])
@pytest.mark.parametrize("optional_processing_kwargs", [{}, {"operation": "test"}])
class TestPayloadParsing:
def test_legacy_payload_parsing(self, payload_parser, legacy_payload, legacy_parsed_payload):
parsed_payload = payload_parser(legacy_payload)
assert parsed_payload == legacy_parsed_payload
def test_payload_parsing(self, payload_parser, payload, parsed_payload):
result = payload_parser(payload)
assert result == parsed_payload
@pytest.mark.parametrize("x_tenant_id", [None, "klaus"])
@pytest.mark.parametrize("optional_processing_kwargs", [{}, {"operation": "test"}])
class TestPayloadFormatting:
def test_legacy_payload_formatting_for_response(self, legacy_parsed_payload, legacy_queue_response_payload):
formatted_payload = format_to_queue_message_response_body(legacy_parsed_payload)
assert formatted_payload == legacy_queue_response_payload
def test_payload_formatting_for_response(self, parsed_payload, queue_response_payload):
formatted_payload = format_to_queue_message_response_body(parsed_payload)
assert formatted_payload == queue_response_payload
def test_legacy_payload_formatting_for_storage(
self, legacy_parsed_payload, processing_result_json, legacy_storage_payload
):
formatted_payload = format_service_processing_result_for_storage(legacy_parsed_payload, processing_result_json)
assert formatted_payload == legacy_storage_payload
def test_payload_formatting_for_storage(self, parsed_payload, processing_result_json, storage_payload):
formatted_payload = format_service_processing_result_for_storage(parsed_payload, processing_result_json)
assert formatted_payload == storage_payload


@ -1,81 +0,0 @@
import gzip
import json
import pytest
from pyinfra.config import get_config
from pyinfra.payload_processing.payload import get_queue_message_payload_parser
from pyinfra.payload_processing.processor import PayloadProcessor
from pyinfra.storage.storage_info import StorageInfo
from pyinfra.storage.storage_provider import StorageProviderMock
from pyinfra.storage.storages.mock import StorageMock
@pytest.fixture
def bucket_name():
return "test_bucket"
@pytest.fixture
def storage_mock(target_json_file, target_file_path, bucket_name):
storage = StorageMock(target_json_file, target_file_path, bucket_name)
return storage
@pytest.fixture
def storage_info_mock(bucket_name):
return StorageInfo(bucket_name)
@pytest.fixture
def data_processor_mock(processing_result_json):
def inner(data, **kwargs):
return processing_result_json
return inner
@pytest.fixture
def payload_processor(storage_mock, storage_info_mock, data_processor_mock):
storage_provider = StorageProviderMock(storage_mock, storage_info_mock)
payload_parser = get_queue_message_payload_parser(get_config())
return PayloadProcessor(storage_provider, payload_parser, data_processor_mock)
@pytest.mark.parametrize("x_tenant_id", [None, "klaus"])
@pytest.mark.parametrize("optional_processing_kwargs", [{}, {"operation": "test"}])
class TestPayloadProcessor:
def test_payload_processor_yields_correct_response_and_uploads_result_for_legacy_message(
self,
payload_processor,
storage_mock,
bucket_name,
response_file_path,
legacy_payload,
legacy_queue_response_payload,
legacy_storage_payload,
):
response = payload_processor(legacy_payload)
assert response == legacy_queue_response_payload
data_stored = storage_mock.get_object(bucket_name, response_file_path)
assert json.loads(gzip.decompress(data_stored).decode()) == legacy_storage_payload
def test_payload_processor_yields_correct_response_and_uploads_result(
self,
payload_processor,
storage_mock,
bucket_name,
response_file_path,
payload,
queue_response_payload,
storage_payload,
):
response = payload_processor(payload)
assert response == queue_response_payload
data_stored = storage_mock.get_object(bucket_name, response_file_path)
assert json.loads(gzip.decompress(data_stored).decode()) == storage_payload