Compare commits
4 commits: master ... basf_ner_p

| Author | SHA1 | Date |
|---|---|---|
|  | 86840ecaca |  |
|  | 225aa1f4ad |  |
|  | 8d2a9240d1 |  |
|  | 9f39be7077 |  |
106  .dockerignore  Normal file
@@ -0,0 +1,106 @@
data
/build_venv/
/.venv/
/misc/
/incl/image_service/test/
/scratch/
/bamboo-specs/
README.md
Dockerfile
*idea
*misc
*egg-info
*pycache*

# Git
.git
.gitignore

# CI
.codeclimate.yml
.travis.yml
.taskcluster.yml

# Docker
.docker

# Byte-compiled / optimized / DLL files
__pycache__/
*/__pycache__/
*/*/__pycache__/
*/*/*/__pycache__/
*.py[cod]
*/*.py[cod]
*/*/*.py[cod]
*/*/*/*.py[cod]

# C extensions
*.so

# Distribution / packaging
.Python
env/
build/
develop-eggs/
dist/
downloads/
eggs/
lib/
lib64/
parts/
sdist/
var/
*.egg-info/**
.installed.cfg
*.egg

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.cache
nosetests.xml
coverage.xml

# Translations
*.mo
*.pot

# Django stuff:
*.log

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Virtual environment
.env/
.venv/
#venv/

# PyCharm
.idea

# Python mode for VIM
.ropeproject
*/.ropeproject
*/*/.ropeproject
*/*/*/.ropeproject

# Vim swap files
*.swp
*/*.swp
*/*/*.swp
*/*/*/*.swp

2  .dvc/.gitignore  vendored
@@ -1,2 +0,0 @@
/config.local
/cache

.dvc/config
@@ -1,5 +0,0 @@
[core]
remote = azure
['remote "azure"']
url = azure://pyinfra-dvc
connection_string =

.dvcignore
@@ -1,3 +0,0 @@
# Add patterns of files dvc should ignore, which could improve
# the performance. Learn more at
# https://dvc.org/doc/user-guide/dvcignore

57  .gitignore  vendored
@@ -1,53 +1,10 @@
# Environments
.env
.venv
env/
venv/
.DS_Store

# Project folders
*.vscode/
.idea
*_app
*pytest_cache
*joblib
*tmp
*profiling
*logs
*docker
*drivers
*bamboo-specs/target
.coverage
data
__pycache__
data/
build_venv
reports

# Python specific files
__pycache__/
*.py[cod]
*.ipynb
*.ipynb_checkpoints

# file extensions
*.log
*.csv
*.pkl
*.profile
*.cbm
*.egg-info

# temp files
*.swp
*~
*.un~

# keep files
!notebooks/*.ipynb

# keep folders
!secrets
!data/*
!drivers

# ignore files
bamboo.yml
pyinfra.egg-info
bamboo-specs/target
.pytest_cache
/.coverage
.idea

.gitlab-ci.yml
@@ -1,23 +0,0 @@
# CI for services, check gitlab repo for python package CI
include:
  - project: "Gitlab/gitlab"
    ref: main
    file: "/ci-templates/research/python_pkg-test-build-release.gitlab-ci.yml"

# set project variables here
variables:
  NEXUS_PROJECT_DIR: research # subfolder in Nexus docker-gin where your container will be stored
  IMAGENAME: $CI_PROJECT_NAME # if the project URL is gitlab.example.com/group-name/project-1, CI_PROJECT_NAME is project-1
  REPORTS_DIR: reports
  FF_USE_FASTZIP: "true" # enable fastzip - a faster zip implementation that also supports level configuration.
  ARTIFACT_COMPRESSION_LEVEL: default # can also be set to fastest, fast, slow and slowest. If just enabling fastzip is not enough, try setting this to fastest or fast.
  CACHE_COMPRESSION_LEVEL: default # same as above, but for caches
  # TRANSFER_METER_FREQUENCY: 5s # will display transfer progress every 5 seconds for artifacts and remote caches. For debugging purposes.


############
# UNIT TESTS
unit-tests:
  variables:
    ###### UPDATE/EDIT ######
    UNIT_TEST_DIR: "tests/unit_test"

.pre-commit-config.yaml
@@ -1,55 +0,0 @@
# See https://pre-commit.com for more information
# See https://pre-commit.com/hooks.html for more hooks
exclude: ^(docs/|notebooks/|data/|src/configs/|tests/|.hooks/)
default_language_version:
  python: python3.10
repos:
  - repo: https://github.com/pre-commit/pre-commit-hooks
    rev: v5.0.0
    hooks:
      - id: trailing-whitespace
      - id: end-of-file-fixer
      - id: check-yaml
        name: Check Gitlab CI (unsafe)
        args: [--unsafe]
        files: .gitlab-ci.yml
      - id: check-yaml
        exclude: .gitlab-ci.yml
      - id: check-toml
      - id: detect-private-key
      - id: check-added-large-files
        args: ['--maxkb=10000']
      - id: check-case-conflict
      - id: mixed-line-ending

  - repo: https://github.com/pre-commit/mirrors-pylint
    rev: v3.0.0a5
    hooks:
      - id: pylint
        language: system
        args:
          - --disable=C0111,R0903
          - --max-line-length=120

  - repo: https://github.com/pre-commit/mirrors-isort
    rev: v5.10.1
    hooks:
      - id: isort
        args:
          - --profile black

  - repo: https://github.com/psf/black
    rev: 24.10.0
    hooks:
      - id: black
        # exclude: ^(docs/|notebooks/|data/|src/secrets/)
        args:
          - --line-length=120

  - repo: https://github.com/compilerla/conventional-pre-commit
    rev: v3.6.0
    hooks:
      - id: conventional-pre-commit
        pass_filenames: false
        stages: [commit-msg]
        # args: [] # optional: list of Conventional Commits types to allow e.g. [feat, fix, ci, chore, test]

@@ -1 +0,0 @@
3.10

19  Dockerfile_tests  Normal file
@@ -0,0 +1,19 @@
ARG BASE_ROOT="nexus.iqser.com:5001/red/"
ARG VERSION_TAG="dev"

FROM ${BASE_ROOT}pyinfra:${VERSION_TAG}

EXPOSE 5000
EXPOSE 8080

RUN python3 -m pip install coverage

# Make a directory for the service files and copy the service repo into the container.
WORKDIR /app/service
COPY . .

# Install module & dependencies
RUN python3 -m pip install -e .
RUN python3 -m pip install -r requirements.txt

CMD coverage run -m pytest test/ -x && coverage report -m && coverage xml

85  Makefile
@@ -1,85 +0,0 @@
.PHONY: \
	poetry in-project-venv dev-env use-env install install-dev tests \
	update-version sync-version-with-git \
	docker docker-build-run docker-build docker-run \
	docker-rm docker-rm-container docker-rm-image \
	pre-commit get-licenses prep-commit \
	docs sphinx_html sphinx_apidoc
.DEFAULT_GOAL := run

export DOCKER=docker
export DOCKERFILE=Dockerfile
export IMAGE_NAME=rule_engine-image
export CONTAINER_NAME=rule_engine-container
export HOST_PORT=9999
export CONTAINER_PORT=9999
export PYTHON_VERSION=python3.8

# all commands should be executed in the root dir of the project,
# specific environments should be deactivated

poetry: in-project-venv use-env dev-env

in-project-venv:
	poetry config virtualenvs.in-project true

use-env:
	poetry env use ${PYTHON_VERSION}

dev-env:
	poetry install --with dev

install:
	poetry add $(pkg)

install-dev:
	poetry add --dev $(pkg)

requirements:
	poetry export --without-hashes --output requirements.txt

update-version:
	poetry version prerelease

sync-version-with-git:
	git pull -p && poetry version $(git rev-list --tags --max-count=1 | git describe --tags --abbrev=0)

docker: docker-rm docker-build-run

docker-build-run: docker-build docker-run

docker-build:
	$(DOCKER) build \
		--no-cache --progress=plain \
		-t $(IMAGE_NAME) -f $(DOCKERFILE) .

docker-run:
	$(DOCKER) run -it --rm -p $(HOST_PORT):$(CONTAINER_PORT)/tcp --name $(CONTAINER_NAME) $(IMAGE_NAME) python app.py

docker-rm: docker-rm-container docker-rm-image

docker-rm-container:
	-$(DOCKER) rm $(CONTAINER_NAME)

docker-rm-image:
	-$(DOCKER) image rm $(IMAGE_NAME)

tests:
	poetry run pytest ./tests

prep-commit:
	docs get-license sync-version-with-git update-version pre-commit

pre-commit:
	pre-commit run --all-files

get-licenses:
	pip-licenses --format=json --order=license --with-urls > pkg-licenses.json

docs: sphinx_apidoc sphinx_html

sphinx_html:
	poetry run sphinx-build -b html docs/source/ docs/build/html -E -a

sphinx_apidoc:
	poetry run sphinx-apidoc -o ./docs/source/modules ./src/rule_engine

248  README.md
@@ -1,220 +1,106 @@
# PyInfra
# Infrastructure to deploy Research Projects

1. [ About ](#about)
2. [ Configuration ](#configuration)
3. [ Queue Manager ](#queue-manager)
4. [ Module Installation ](#module-installation)
5. [ Scripts ](#scripts)
6. [ Tests ](#tests)
7. [ Opentelemetry protobuf dependency hell ](#opentelemetry-protobuf-dependency-hell)

## About

Shared library for the research team, containing code related to infrastructure and communication with other services.
Offers a simple interface for processing data and sending responses via AMQP, monitoring via Prometheus and storage
access via S3 or Azure. Also exports traces via OpenTelemetry for queue messages and webserver requests.

To start, see the [complete example](pyinfra/examples.py), which shows how to use all features of the service and can be
imported and used directly for default research service pipelines (data ID in message, download data from storage,
upload result, while offering Prometheus monitoring, /health and /ready endpoints and multi-tenancy support).
The infrastructure expects to be deployed in the same pod / local environment as the analysis container and handles all outbound communication.

## Configuration

Configuration is done via `Dynaconf`. This means that you can use environment variables, a `.env` file or `.toml`
file(s) to configure the service. You can also combine these methods. The precedence is
`environment variables > .env > .toml`. It is recommended to load settings with the provided
[`load_settings`](pyinfra/config/loader.py) function, which you can combine with the provided
[`parse_args`](pyinfra/config/loader.py) function. This allows you to load settings from a `.toml` file or a folder with
`.toml` files and override them with environment variables.
A configuration is located in `/config.yaml`. All relevant variables can be configured by exporting environment variables.
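
For example, a service entry point might load its settings like this (a minimal sketch; the `config/` folder is illustrative and the snippet assumes the `load_settings` helper from `pyinfra.config.loader` described above):

```python
# Minimal sketch: load every .toml file from a config folder; values from
# .env and from environment variables are merged on top (env vars win).
from pyinfra.config.loader import load_settings

settings = load_settings("config/")  # illustrative path
print(settings.rabbitmq.host)
```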

The following table shows all necessary settings. You can find a preconfigured settings file for this service in
bitbucket. These are the complete settings; you only need all of them if you use all features of the service as described in
the [complete example](pyinfra/examples.py).

| Environment Variable | Default | Description |
|---|---|---|
| LOGGING_LEVEL_ROOT | DEBUG | Logging level for service logger |
| PROBING_WEBSERVER_HOST | "0.0.0.0" | Probe webserver address |
| PROBING_WEBSERVER_PORT | 8080 | Probe webserver port |
| PROBING_WEBSERVER_MODE | production | Webserver mode: {development, production} |
| RABBITMQ_HOST | localhost | RabbitMQ host address |
| RABBITMQ_PORT | 5672 | RabbitMQ host port |
| RABBITMQ_USERNAME | user | RabbitMQ username |
| RABBITMQ_PASSWORD | bitnami | RabbitMQ password |
| RABBITMQ_HEARTBEAT | 7200 | Controls AMQP heartbeat timeout in seconds |
| REQUEST_QUEUE | request_queue | Requests to service |
| RESPONSE_QUEUE | response_queue | Responses by service |
| DEAD_LETTER_QUEUE | dead_letter_queue | Messages that failed to process |
| ANALYSIS_ENDPOINT | "http://127.0.0.1:5000" | Endpoint for analysis container |
| STORAGE_BACKEND | s3 | The type of storage to use {s3, azure} |
| STORAGE_BUCKET | "redaction" | The bucket / container to pull files specified in queue requests from |
| STORAGE_ENDPOINT | "http://127.0.0.1:9000" | Endpoint for s3 storage |
| STORAGE_KEY | root | User for s3 storage |
| STORAGE_SECRET | password | Password for s3 storage |
| STORAGE_AZURECONNECTIONSTRING | "DefaultEndpointsProtocol=..." | Connection string for Azure storage |
| STORAGE_AZURECONTAINERNAME | "redaction" | AKS container |

| Environment Variable | Internal / .toml Name | Description |
|---|---|---|
| LOGGING\_\_LEVEL | logging.level | Log level |
| DYNAMIC_TENANT_QUEUES\_\_ENABLED | dynamic_tenant_queues.enabled | Enable dynamically created queues per tenant |
| METRICS\_\_PROMETHEUS\_\_ENABLED | metrics.prometheus.enabled | Enable Prometheus metrics collection |
| METRICS\_\_PROMETHEUS\_\_PREFIX | metrics.prometheus.prefix | Prefix for Prometheus metrics (e.g. {product}-{service}) |
| WEBSERVER\_\_HOST | webserver.host | Host of the webserver (offering e.g. /prometheus, /ready and /health endpoints) |
| WEBSERVER\_\_PORT | webserver.port | Port of the webserver |
| RABBITMQ\_\_HOST | rabbitmq.host | Host of the RabbitMQ server |
| RABBITMQ\_\_PORT | rabbitmq.port | Port of the RabbitMQ server |
| RABBITMQ\_\_USERNAME | rabbitmq.username | Username for the RabbitMQ server |
| RABBITMQ\_\_PASSWORD | rabbitmq.password | Password for the RabbitMQ server |
| RABBITMQ\_\_HEARTBEAT | rabbitmq.heartbeat | Heartbeat for the RabbitMQ server |
| RABBITMQ\_\_CONNECTION_SLEEP | rabbitmq.connection_sleep | Sleep time interval during message processing. Has to be a divider of the heartbeat and shouldn't be too big, since queue interactions (like receiving new messages) only happen in these intervals. This is also the minimum time the service needs to process a message. |
| RABBITMQ\_\_INPUT_QUEUE | rabbitmq.input_queue | Name of the input queue in single queue setting |
| RABBITMQ\_\_OUTPUT_QUEUE | rabbitmq.output_queue | Name of the output queue in single queue setting |
| RABBITMQ\_\_DEAD_LETTER_QUEUE | rabbitmq.dead_letter_queue | Name of the dead letter queue in single queue setting |
| RABBITMQ\_\_TENANT_EVENT_QUEUE_SUFFIX | rabbitmq.tenant_event_queue_suffix | Suffix for the tenant event queue in multi tenant/queue setting |
| RABBITMQ\_\_TENANT_EVENT_DLQ_SUFFIX | rabbitmq.tenant_event_dlq_suffix | Suffix for the dead letter queue in multi tenant/queue setting |
| RABBITMQ\_\_TENANT_EXCHANGE_NAME | rabbitmq.tenant_exchange_name | Name of tenant exchange in multi tenant/queue setting |
| RABBITMQ\_\_QUEUE_EXPIRATION_TIME | rabbitmq.queue_expiration_time | Time until queue expiration in multi tenant/queue setting |
| RABBITMQ\_\_SERVICE_REQUEST_QUEUE_PREFIX | rabbitmq.service_request_queue_prefix | Service request queue prefix in multi tenant/queue setting |
| RABBITMQ\_\_SERVICE_REQUEST_EXCHANGE_NAME | rabbitmq.service_request_exchange_name | Service request exchange name in multi tenant/queue setting |
| RABBITMQ\_\_SERVICE_RESPONSE_EXCHANGE_NAME | rabbitmq.service_response_exchange_name | Service response exchange name in multi tenant/queue setting |
| RABBITMQ\_\_SERVICE_DLQ_NAME | rabbitmq.service_dlq_name | Service dead letter queue name in multi tenant/queue setting |
| STORAGE\_\_BACKEND | storage.backend | Storage backend to use (currently only "s3" and "azure" are supported) |
| STORAGE\_\_S3\_\_BUCKET | storage.s3.bucket | Name of the S3 bucket |
| STORAGE\_\_S3\_\_ENDPOINT | storage.s3.endpoint | Endpoint of the S3 server |
| STORAGE\_\_S3\_\_KEY | storage.s3.key | Access key for the S3 server |
| STORAGE\_\_S3\_\_SECRET | storage.s3.secret | Secret key for the S3 server |
| STORAGE\_\_S3\_\_REGION | storage.s3.region | Region of the S3 server |
| STORAGE\_\_AZURE\_\_CONTAINER | storage.azure.container_name | Name of the Azure container |
| STORAGE\_\_AZURE\_\_CONNECTION_STRING | storage.azure.connection_string | Connection string for the Azure server |
| STORAGE\_\_TENANT_SERVER\_\_PUBLIC_KEY | storage.tenant_server.public_key | Public key of the tenant server |
| STORAGE\_\_TENANT_SERVER\_\_ENDPOINT | storage.tenant_server.endpoint | Endpoint of the tenant server |
| TRACING\_\_ENABLED | tracing.enabled | Enable tracing |
| TRACING\_\_TYPE | tracing.type | Tracing mode - possible values: "opentelemetry", "azure_monitor" (expects the APPLICATIONINSIGHTS_CONNECTION_STRING environment variable) |
| TRACING\_\_OPENTELEMETRY\_\_ENDPOINT | tracing.opentelemetry.endpoint | Endpoint to which OpenTelemetry traces are exported |
| TRACING\_\_OPENTELEMETRY\_\_SERVICE_NAME | tracing.opentelemetry.service_name | Name of the service as displayed in the traces collected |
| TRACING\_\_OPENTELEMETRY\_\_EXPORTER | tracing.opentelemetry.exporter | Name of exporter |
| KUBERNETES\_\_POD_NAME | kubernetes.pod_name | Service pod name |
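
The double underscore in the environment variable names maps onto the nesting of the `.toml` keys, as documented in the table above. A small illustrative sketch of that mapping, using `Dynaconf` directly the same way the loader configures it (the value is an example):

```python
# Illustrative: RABBITMQ__HOST surfaces as settings.rabbitmq.host.
import os
from dynaconf import Dynaconf

os.environ["RABBITMQ__HOST"] = "rabbitmq.internal"  # example value

settings = Dynaconf(envvar_prefix=False, load_dotenv=True)
print(settings.rabbitmq.host)  # -> "rabbitmq.internal"
```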
|
||||
## Response Format
|
||||
|
||||
## Setup
|
||||
**IMPORTANT** you need to set the following environment variables before running the setup script:
|
||||
- ``$NEXUS_USER`` your Nexus user (usually equal to firstname.lastname@knecon.com)
|
||||
- ``$NEXUS_PASSWORD`` your Nexus password (usually equal to your Azure Login)
|
||||
|
||||
```shell
|
||||
# create venv and activate it
|
||||
source ./scripts/setup/devenvsetup.sh {{ cookiecutter.python_version }} $NEXUS_USER $NEXUS_PASSWORD
|
||||
source .venv/bin/activate
|
||||
```
|
||||
|
||||
### OpenTelemetry
|
||||
|
||||
Open telemetry (vis its Python SDK) is set up to be as unobtrusive as possible; for typical use cases it can be
|
||||
configured
|
||||
from environment variables, without additional work in the microservice app, although additional confiuration is
|
||||
possible.
|
||||
|
||||
`TRACING__OPENTELEMETRY__ENDPOINT` should typically be set
|
||||
to `http://otel-collector-opentelemetry-collector.otel-collector:4318/v1/traces`.
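
In practice that means tracing can be switched on entirely from the deployment environment, e.g. (illustrative values, using the variable names from the table above):

```python
# Illustrative: enable OpenTelemetry tracing via environment variables only.
import os

os.environ["TRACING__ENABLED"] = "true"
os.environ["TRACING__TYPE"] = "opentelemetry"
os.environ["TRACING__OPENTELEMETRY__ENDPOINT"] = (
    "http://otel-collector-opentelemetry-collector.otel-collector:4318/v1/traces"
)
os.environ["TRACING__OPENTELEMETRY__SERVICE_NAME"] = "my-service"  # example name
os.environ["TRACING__OPENTELEMETRY__EXPORTER"] = "otlp"            # example exporter name
```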

## Queue Manager

The queue manager is responsible for consuming messages from the input queue, processing them and sending the response
to the output queue. The default callback also downloads data from the storage and uploads the result to the storage.
The response message does not contain the data itself, but the identifiers from the input message (including headers
beginning with "X-").

### Standalone Usage

```python
from pyinfra.queue.manager import QueueManager
from pyinfra.queue.callback import make_download_process_upload_callback, DataProcessor
from pyinfra.config.loader import load_settings

settings = load_settings("path/to/settings")
processing_function: DataProcessor  # function should expect a dict (json) or bytes (pdf) as input and should return a json serializable object.

queue_manager = QueueManager(settings)
callback = make_download_process_upload_callback(processing_function, settings)
queue_manager.start_consuming(callback)
```
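
The `processing_function` can be any callable that satisfies the `DataProcessor` contract described above, for example (a hypothetical sketch, not part of pyinfra):

```python
# Hypothetical processing function: dict (JSON) or bytes (PDF) in,
# JSON-serializable object out.
from typing import Union


def processing_function(data: Union[dict, bytes]) -> dict:
    if isinstance(data, bytes):
        # placeholder logic for a raw PDF payload
        return {"size_bytes": len(data)}
    return {"keys": sorted(data.keys())}
```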

### Usage in a Service

This is the recommended way to use the module. This includes the webserver, Prometheus metrics and health endpoints.
Custom endpoints can be added by adding a new route to the `app` object beforehand. Settings are loaded from files
specified as CLI arguments (e.g. `--settings-path path/to/settings.toml`). The values can also be set or overridden via
environment variables (e.g. `LOGGING__LEVEL=DEBUG`).

The callback can be replaced with a custom one, for example if the data to process is contained in the message itself
and not on the storage.

```python
from pyinfra.config.loader import load_settings, parse_settings_path
from pyinfra.examples import start_standard_queue_consumer
from pyinfra.queue.callback import make_download_process_upload_callback, DataProcessor

processing_function: DataProcessor

arguments = parse_settings_path()
settings = load_settings(arguments.settings_path)

callback = make_download_process_upload_callback(processing_function, settings)
start_standard_queue_consumer(callback, settings)  # optionally also pass a fastAPI app object with preconfigured routes
```

### AMQP input message:

Either use the legacy format with dossierId and fileId as strings or the new format where absolute paths are used.
All headers beginning with "X-" are forwarded to the message processor, and returned in the response message (e.g.
"X-TENANT-ID" is used to acquire storage information for the tenant).

### Expected AMQP input message:

```json
{
  "targetFilePath": "",
  "responseFilePath": "",
  "dossierId": "",
  "fileId": "",
  "targetFileExtension": "",
  "responseFileExtension": ""
}
```

or
Optionally, the input message can contain a field with the key `"operations"`.

### AMQP output message:

```json
{
  "dossierId": "",
  "fileId": "",
  "targetFileExtension": "",
  "responseFileExtension": ""
  ...
}
```

## Module Installation
## Development

Add the respective version of the pyinfra package to your pyproject.toml file. Make sure to add our gitlab registry as a
source.
For now, all internal packages used by pyinfra also have to be added to the pyproject.toml file (namely kn-utils).
Execute `poetry lock` and `poetry install` to install the packages.
Either run `src/serve.py` or the built Docker image.

You can look up the latest version of the package in
the [gitlab registry](https://gitlab.knecon.com/knecon/research/pyinfra/-/packages).
For the used versions of internal dependencies, please refer to the [pyproject.toml](pyproject.toml) file.

### Setup

```toml
[tool.poetry.dependencies]
pyinfra = { version = "x.x.x", source = "gitlab-research" }
kn-utils = { version = "x.x.x", source = "gitlab-research" }

[[tool.poetry.source]]
name = "gitlab-research"
url = "https://gitlab.knecon.com/api/v4/groups/19/-/packages/pypi/simple"
priority = "explicit"
```

## Scripts

### Run pyinfra locally

**Shell 1**: Start minio and rabbitmq containers
Install module.

```bash
$ cd tests && docker compose up
pip install -e .
pip install -r requirements.txt
```

**Shell 2**: Start pyinfra with callback mock
or build docker image.

```bash
$ python scripts/start_pyinfra.py
docker build -f Dockerfile -t pyinfra .
```

**Shell 3**: Upload dummy content on storage and publish message
### Usage

**Shell 1:** Start a MinIO and a RabbitMQ docker container.

```bash
$ python scripts/send_request.py
docker-compose up
```

## Tests
**Shell 2:** Add files to the local minio storage.

Tests require a running minio and rabbitmq container, meaning you have to run `docker compose up` in the tests folder
before running the tests.

```bash
python scripts/manage_minio.py add <MinIO target folder> -d path/to/a/folder/with/PDFs
```

## OpenTelemetry Protobuf Dependency Hell
**Shell 2:** Run pyinfra-server.

**Note**: Status 2025/01/09: the currently used `opentelemetry-exporter-otlp-proto-http` version `1.25.0` requires
a `protobuf` version < `5.x.x` and is not compatible with the latest protobuf version `5.27.x`. This is an [open issue](https://github.com/open-telemetry/opentelemetry-python/issues/3958) in opentelemetry, because [support for 4.25.x ends in Q2 '25](https://protobuf.dev/support/version-support/#python).
Therefore, we should keep this in mind and update the dependency once opentelemetry includes support for `protobuf 5.27.x`.

```bash
python src/serve.py
```

or as container:

```bash
docker run --net=host pyinfra
```

**Shell 3:** Run analysis-container.

**Shell 4:** Start a client that sends requests to process PDFs from the MinIO store and annotates these PDFs according to the service responses.

```bash
python scripts/mock_client.py
```

40  bamboo-specs/pom.xml  Normal file
@@ -0,0 +1,40 @@
<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
         xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/maven-v4_0_0.xsd">
    <modelVersion>4.0.0</modelVersion>

    <parent>
        <groupId>com.atlassian.bamboo</groupId>
        <artifactId>bamboo-specs-parent</artifactId>
        <version>7.1.2</version>
        <relativePath/>
    </parent>

    <artifactId>bamboo-specs</artifactId>
    <version>1.0.0-SNAPSHOT</version>
    <packaging>jar</packaging>

    <properties>
        <sonar.skip>true</sonar.skip>
    </properties>

    <dependencies>
        <dependency>
            <groupId>com.atlassian.bamboo</groupId>
            <artifactId>bamboo-specs-api</artifactId>
        </dependency>
        <dependency>
            <groupId>com.atlassian.bamboo</groupId>
            <artifactId>bamboo-specs</artifactId>
        </dependency>

        <!-- Test dependencies -->
        <dependency>
            <groupId>junit</groupId>
            <artifactId>junit</artifactId>
            <scope>test</scope>
        </dependency>
    </dependencies>

    <!-- run 'mvn test' to perform offline validation of the plan -->
    <!-- run 'mvn -Ppublish-specs' to upload the plan to your Bamboo server -->
</project>

148  bamboo-specs/src/main/java/buildjob/PlanSpec.java  Normal file
@@ -0,0 +1,148 @@
package buildjob;

import com.atlassian.bamboo.specs.api.BambooSpec;
import com.atlassian.bamboo.specs.api.builders.BambooKey;
import com.atlassian.bamboo.specs.api.builders.docker.DockerConfiguration;
import com.atlassian.bamboo.specs.api.builders.permission.PermissionType;
import com.atlassian.bamboo.specs.api.builders.permission.Permissions;
import com.atlassian.bamboo.specs.api.builders.permission.PlanPermissions;
import com.atlassian.bamboo.specs.api.builders.plan.Job;
import com.atlassian.bamboo.specs.api.builders.plan.Plan;
import com.atlassian.bamboo.specs.api.builders.plan.PlanIdentifier;
import com.atlassian.bamboo.specs.api.builders.plan.Stage;
import com.atlassian.bamboo.specs.api.builders.plan.branches.BranchCleanup;
import com.atlassian.bamboo.specs.api.builders.plan.branches.PlanBranchManagement;
import com.atlassian.bamboo.specs.api.builders.project.Project;
import com.atlassian.bamboo.specs.builders.task.CheckoutItem;
import com.atlassian.bamboo.specs.builders.task.InjectVariablesTask;
import com.atlassian.bamboo.specs.builders.task.ScriptTask;
import com.atlassian.bamboo.specs.builders.task.VcsCheckoutTask;
import com.atlassian.bamboo.specs.builders.task.CleanWorkingDirectoryTask;
import com.atlassian.bamboo.specs.builders.task.VcsTagTask;
import com.atlassian.bamboo.specs.builders.trigger.BitbucketServerTrigger;
import com.atlassian.bamboo.specs.model.task.InjectVariablesScope;
import com.atlassian.bamboo.specs.api.builders.Variable;
import com.atlassian.bamboo.specs.util.BambooServer;
import com.atlassian.bamboo.specs.builders.task.ScriptTask;
import com.atlassian.bamboo.specs.model.task.ScriptTaskProperties.Location;

/**
 * Plan configuration for Bamboo.
 * Learn more on: <a href="https://confluence.atlassian.com/display/BAMBOO/Bamboo+Specs">https://confluence.atlassian.com/display/BAMBOO/Bamboo+Specs</a>
 */
@BambooSpec
public class PlanSpec {

    private static final String SERVICE_NAME = "pyinfra";

    private static final String SERVICE_KEY = SERVICE_NAME.toUpperCase().replaceAll("-","");

    /**
     * Run main to publish plan on Bamboo
     */
    public static void main(final String[] args) throws Exception {
        //By default credentials are read from the '.credentials' file.
        BambooServer bambooServer = new BambooServer("http://localhost:8085");

        Plan plan = new PlanSpec().createBuildPlan();
        bambooServer.publish(plan);
        PlanPermissions planPermission = new PlanSpec().createPlanPermission(plan.getIdentifier());
        bambooServer.publish(planPermission);
    }

    private PlanPermissions createPlanPermission(PlanIdentifier planIdentifier) {
        Permissions permission = new Permissions()
                .userPermissions("atlbamboo", PermissionType.EDIT, PermissionType.VIEW, PermissionType.ADMIN, PermissionType.CLONE, PermissionType.BUILD)
                .groupPermissions("research", PermissionType.EDIT, PermissionType.VIEW, PermissionType.CLONE, PermissionType.BUILD)
                .groupPermissions("Development", PermissionType.EDIT, PermissionType.VIEW, PermissionType.CLONE, PermissionType.BUILD)
                .groupPermissions("QA", PermissionType.EDIT, PermissionType.VIEW, PermissionType.CLONE, PermissionType.BUILD)
                .loggedInUserPermissions(PermissionType.VIEW)
                .anonymousUserPermissionView();
        return new PlanPermissions(planIdentifier.getProjectKey(), planIdentifier.getPlanKey()).permissions(permission);
    }

    private Project project() {
        return new Project()
                .name("RED")
                .key(new BambooKey("RED"));
    }

    public Plan createBuildPlan() {
        return new Plan(
                project(),
                SERVICE_NAME, new BambooKey(SERVICE_KEY))
                .description("Build for pyinfra")
                .stages(
                        new Stage("Sonar Stage")
                                .jobs(
                                        new Job("Sonar Job", new BambooKey("SONAR"))
                                                .tasks(
                                                        new CleanWorkingDirectoryTask()
                                                                .description("Clean working directory.")
                                                                .enabled(true),
                                                        new VcsCheckoutTask()
                                                                .description("Checkout default repository.")
                                                                .checkoutItems(new CheckoutItem().defaultRepository()),
                                                        new ScriptTask()
                                                                .description("Set config and keys.")
                                                                .inlineBody("mkdir -p ~/.ssh\n" +
                                                                        "echo \"${bamboo.bamboo_agent_ssh}\" | base64 -d >> ~/.ssh/id_rsa\n" +
                                                                        "echo \"host vector.iqser.com\" > ~/.ssh/config\n" +
                                                                        "echo \" user bamboo-agent\" >> ~/.ssh/config\n" +
                                                                        "chmod 600 ~/.ssh/config ~/.ssh/id_rsa"),
                                                        new ScriptTask()
                                                                .description("Run Sonarqube scan.")
                                                                .location(Location.FILE)
                                                                .fileFromPath("bamboo-specs/src/main/resources/scripts/sonar-scan.sh")
                                                                .argument(SERVICE_NAME))
                                                .dockerConfiguration(
                                                        new DockerConfiguration()
                                                                .image("nexus.iqser.com:5001/infra/release_build:4.2.0")
                                                                .volume("/var/run/docker.sock", "/var/run/docker.sock"))),
                        new Stage("Licence Stage")
                                .jobs(
                                        new Job("Git Tag Job", new BambooKey("GITTAG"))
                                                .tasks(
                                                        new VcsCheckoutTask()
                                                                .description("Checkout default repository.")
                                                                .checkoutItems(new CheckoutItem().defaultRepository()),
                                                        new ScriptTask()
                                                                .description("Build git tag.")
                                                                .location(Location.FILE)
                                                                .fileFromPath("bamboo-specs/src/main/resources/scripts/git-tag.sh"),
                                                        new InjectVariablesTask()
                                                                .description("Inject git tag.")
                                                                .path("git.tag")
                                                                .namespace("g")
                                                                .scope(InjectVariablesScope.LOCAL),
                                                        new VcsTagTask()
                                                                .description("${bamboo.g.gitTag}")
                                                                .tagName("${bamboo.g.gitTag}")
                                                                .defaultRepository())
                                                .dockerConfiguration(
                                                        new DockerConfiguration()
                                                                .image("nexus.iqser.com:5001/infra/release_build:4.4.1")),
                                        new Job("Licence Job", new BambooKey("LICENCE"))
                                                .enabled(false)
                                                .tasks(
                                                        new VcsCheckoutTask()
                                                                .description("Checkout default repository.")
                                                                .checkoutItems(new CheckoutItem().defaultRepository()),
                                                        new ScriptTask()
                                                                .description("Build licence.")
                                                                .location(Location.FILE)
                                                                .fileFromPath("bamboo-specs/src/main/resources/scripts/create-licence.sh"))
                                                .dockerConfiguration(
                                                        new DockerConfiguration()
                                                                .image("nexus.iqser.com:5001/infra/maven:3.6.2-jdk-13-3.0.0")
                                                                .volume("/etc/maven/settings.xml", "/usr/share/maven/ref/settings.xml")
                                                                .volume("/var/run/docker.sock", "/var/run/docker.sock"))))
                .linkedRepositories("RR / " + SERVICE_NAME)
                .triggers(new BitbucketServerTrigger())
                .planBranchManagement(new PlanBranchManagement()
                        .createForVcsBranch()
                        .delete(new BranchCleanup()
                                .whenInactiveInRepositoryAfterDays(14))
                        .notificationForCommitters());
    }
}

19  bamboo-specs/src/main/resources/scripts/create-licence.sh  Executable file
@@ -0,0 +1,19 @@
#!/bin/bash
set -e

if [[ \"${bamboo_version_tag}\" != \"dev\" ]]
then
    ${bamboo_capability_system_builder_mvn3_Maven_3}/bin/mvn \
        -f ${bamboo_build_working_directory}/pom.xml \
        versions:set \
        -DnewVersion=${bamboo_version_tag}

    ${bamboo_capability_system_builder_mvn3_Maven_3}/bin/mvn \
        -f ${bamboo_build_working_directory}/pom.xml \
        -B clean deploy \
        -e -DdeployAtEnd=true \
        -Dmaven.wagon.http.ssl.insecure=true \
        -Dmaven.wagon.http.ssl.allowall=true \
        -Dmaven.wagon.http.ssl.ignore.validity.dates=true \
        -DaltDeploymentRepository=iqser_release::default::https://nexus.iqser.com/repository/gin4-platform-releases
fi

9  bamboo-specs/src/main/resources/scripts/git-tag.sh  Executable file
@@ -0,0 +1,9 @@
#!/bin/bash
set -e

if [[ "${bamboo_version_tag}" = "dev" ]]
then
    echo "gitTag=${bamboo_planRepository_1_branch}_${bamboo_buildNumber}" > git.tag
else
    echo "gitTag=${bamboo_version_tag}" > git.tag
fi

56  bamboo-specs/src/main/resources/scripts/sonar-scan.sh  Executable file
@@ -0,0 +1,56 @@
#!/bin/bash
set -e

export JAVA_HOME=/usr/bin/sonar-scanner/jre

python3 -m venv build_venv
source build_venv/bin/activate
python3 -m pip install --upgrade pip
python3 -m pip install dependency-check
python3 -m pip install docker-compose
python3 -m pip install coverage

# This is disabled since there are currently no tests in this project.
# If tests are added this can be enabled again
# echo "coverage report generation"
# bash run_tests.sh

# if [ ! -f reports/coverage.xml ]
# then
#     exit 1
# fi

SERVICE_NAME=$1

echo "dependency-check:aggregate"
mkdir -p reports
dependency-check --enableExperimental -f JSON -f XML \
    --disableAssembly -s . -o reports --project $SERVICE_NAME --exclude ".git/**" --exclude "venv/**" \
    --exclude "build_venv/**" --exclude "**/__pycache__/**" --exclude "bamboo-specs/**"

if [[ -z "${bamboo_repository_pr_key}" ]]
then
    echo "Sonar Scan for branch: ${bamboo_planRepository_1_branch}"
    /usr/bin/sonar-scanner/bin/sonar-scanner -X \
        -Dsonar.projectKey=RED_$SERVICE_NAME \
        -Dsonar.host.url=https://sonarqube.iqser.com \
        -Dsonar.login=${bamboo_sonarqube_api_token_secret} \
        -Dsonar.dependencyCheck.jsonReportPath=reports/dependency-check-report.json \
        -Dsonar.dependencyCheck.xmlReportPath=reports/dependency-check-report.xml \
        -Dsonar.dependencyCheck.htmlReportPath=reports/dependency-check-report.html \
        -Dsonar.python.coverage.reportPaths=reports/coverage.xml

else
    echo "Sonar Scan for PR with key: ${bamboo_repository_pr_key}"
    /usr/bin/sonar-scanner/bin/sonar-scanner \
        -Dsonar.projectKey=RED_$SERVICE_NAME \
        -Dsonar.host.url=https://sonarqube.iqser.com \
        -Dsonar.login=${bamboo_sonarqube_api_token_secret} \
        -Dsonar.pullrequest.key=${bamboo_repository_pr_key} \
        -Dsonar.pullrequest.branch=${bamboo_repository_pr_sourceBranch} \
        -Dsonar.pullrequest.base=${bamboo_repository_pr_targetBranch} \
        -Dsonar.dependencyCheck.jsonReportPath=reports/dependency-check-report.json \
        -Dsonar.dependencyCheck.xmlReportPath=reports/dependency-check-report.xml \
        -Dsonar.dependencyCheck.htmlReportPath=reports/dependency-check-report.html \
        -Dsonar.python.coverage.reportPaths=reports/coverage.xml
fi

16  bamboo-specs/src/test/java/buildjob/PlanSpecTest.java  Normal file
@@ -0,0 +1,16 @@
package buildjob;


import com.atlassian.bamboo.specs.api.builders.plan.Plan;
import com.atlassian.bamboo.specs.api.exceptions.PropertiesValidationException;
import com.atlassian.bamboo.specs.api.util.EntityPropertiesBuilders;
import org.junit.Test;

public class PlanSpecTest {
    @Test
    public void checkYourPlanOffline() throws PropertiesValidationException {
        Plan plan = new PlanSpec().createBuildPlan();

        EntityPropertiesBuilders.build(plan);
    }
}

6949  poetry.lock  generated
File diff suppressed because it is too large.

pyinfra/__init__.py
@@ -1 +1,3 @@
from pyinfra import k8s_probes, queue, storage, config

__all__ = ["k8s_probes", "queue", "storage", "config"]

72  pyinfra/config.py  Normal file
@@ -0,0 +1,72 @@
from os import environ


def read_from_environment(environment_variable_name, default_value):
    return environ.get(environment_variable_name, default_value)


class Config(object):
    def __init__(self):
        # Logging level for service logger
        self.logging_level_root = read_from_environment("LOGGING_LEVEL_ROOT", "DEBUG")

        # RabbitMQ host address
        self.rabbitmq_host = read_from_environment("RABBITMQ_HOST", "localhost")

        # RabbitMQ host port
        self.rabbitmq_port = read_from_environment("RABBITMQ_PORT", "5672")

        # RabbitMQ username
        self.rabbitmq_username = read_from_environment("RABBITMQ_USERNAME", "user")

        # RabbitMQ password
        self.rabbitmq_password = read_from_environment("RABBITMQ_PASSWORD", "bitnami")

        # Controls AMQP heartbeat timeout in seconds
        self.rabbitmq_heartbeat = read_from_environment("RABBITMQ_HEARTBEAT", "60")

        # Controls AMQP connection sleep timer in seconds
        # important for heartbeat to come through while main function runs on other thread
        self.rabbitmq_connection_sleep = read_from_environment("RABBITMQ_CONNECTION_SLEEP", 5)

        # Queue name for requests to the service
        self.request_queue = read_from_environment("REQUEST_QUEUE", "request_queue")

        # Queue name for responses by service
        self.response_queue = read_from_environment("RESPONSE_QUEUE", "response_queue")

        # Queue name for failed messages
        self.dead_letter_queue = read_from_environment("DEAD_LETTER_QUEUE", "dead_letter_queue")

        # The type of storage to use {s3, azure}
        self.storage_backend = read_from_environment("STORAGE_BACKEND", "s3")

        # The bucket / container to pull files specified in queue requests from
        if self.storage_backend == "s3":
            self.storage_bucket = read_from_environment("STORAGE_BUCKET_NAME", "redaction")
        else:
            self.storage_bucket = read_from_environment("STORAGE_AZURECONTAINERNAME", "redaction")

        # Endpoint for s3 storage
        self.storage_endpoint = read_from_environment("STORAGE_ENDPOINT", "http://127.0.0.1:9000")

        # User for s3 storage
        self.storage_key = read_from_environment("STORAGE_KEY", "root")

        # Password for s3 storage
        self.storage_secret = read_from_environment("STORAGE_SECRET", "password")

        # Region for s3 storage
        self.storage_region = read_from_environment("STORAGE_REGION", "eu-central-1")

        # Connection string for Azure storage
        self.storage_azureconnectionstring = read_from_environment(
            "STORAGE_AZURECONNECTIONSTRING", "DefaultEndpointsProtocol=..."
        )

        # Value to see if we should write a consumer token to a file
        self.write_consumer_token = read_from_environment("WRITE_CONSUMER_TOKEN", "False")


def get_config() -> Config:
    return Config()
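
A consumer of this module reads the configuration once at startup; an illustrative usage of the `Config` object added above (the environment value is an example):

```python
# Illustrative usage of pyinfra.config.get_config().
import os

os.environ["RABBITMQ_HOST"] = "rabbitmq.internal"  # example override

from pyinfra.config import get_config

config = get_config()
print(config.rabbitmq_host)   # "rabbitmq.internal"
print(config.request_queue)   # "request_queue" (default)
```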

pyinfra/config/loader.py
@@ -1,133 +0,0 @@
import argparse
import os
from functools import partial
from pathlib import Path
from typing import Union

from dynaconf import Dynaconf, ValidationError, Validator
from funcy import lflatten
from kn_utils.logging import logger

# This path is meant for testing purposes and convenience. It probably won't reflect the actual root path when pyinfra is
# installed as a package, so don't use it in production code, but define your own root path as described in load config.
local_pyinfra_root_path = Path(__file__).parents[2]


def load_settings(
    settings_path: Union[str, Path, list] = "config/",
    root_path: Union[str, Path] = None,
    validators: list[Validator] = None,
):
    """Load settings from .toml files, .env and environment variables. Also ensures a ROOT_PATH environment variable is
    set. If ROOT_PATH is not set and no root_path argument is passed, the current working directory is used as root.
    Settings paths can be a single .toml file, a folder containing .toml files or a list of .toml files and folders.
    If a ROOT_PATH environment variable is set, it is not overwritten by the root_path argument.
    If a folder is passed, all .toml files in the folder are loaded. If settings path is None, only .env and
    environment variables are loaded. If settings_path are relative paths, they are joined with the root_path argument.
    """

    root_path = get_or_set_root_path(root_path)
    validators = validators or get_pyinfra_validators()

    settings_files = normalize_to_settings_files(settings_path, root_path)

    settings = Dynaconf(
        load_dotenv=True,
        envvar_prefix=False,
        settings_files=settings_files,
    )

    validate_settings(settings, validators)
    logger.info("Settings loaded and validated.")

    return settings


def normalize_to_settings_files(settings_path: Union[str, Path, list], root_path: Union[str, Path]):
    if settings_path is None:
        logger.info("No settings path specified, only loading .env and ENVs.")
        settings_files = []
    elif isinstance(settings_path, str) or isinstance(settings_path, Path):
        settings_files = [settings_path]
    elif isinstance(settings_path, list):
        settings_files = settings_path
    else:
        raise ValueError(f"Invalid settings path: {settings_path=}")

    settings_files = lflatten(map(partial(_normalize_and_verify, root_path=root_path), settings_files))
    logger.debug(f"Normalized settings files: {settings_files}")

    return settings_files


def _normalize_and_verify(settings_path: Path, root_path: Path):
    settings_path = Path(settings_path)
    root_path = Path(root_path)

    if not settings_path.is_absolute():
        logger.debug(f"Settings path is not absolute, joining with root path: {root_path}")
        settings_path = root_path / settings_path

    if settings_path.is_dir():
        logger.debug(f"Settings path is a directory, loading all .toml files in the directory: {settings_path}")
        settings_files = list(settings_path.glob("*.toml"))
    elif settings_path.is_file():
        logger.debug(f"Settings path is a file, loading specified file: {settings_path}")
        settings_files = [settings_path]
    else:
        raise ValueError(f"Invalid settings path: {settings_path=}, {root_path=}")

    return settings_files


def get_or_set_root_path(root_path: Union[str, Path] = None):
    env_root_path = os.environ.get("ROOT_PATH")

    if env_root_path:
        root_path = env_root_path
        logger.debug(f"'ROOT_PATH' environment variable is set to {root_path}.")

    elif root_path:
        logger.info(f"'ROOT_PATH' environment variable is not set, setting to {root_path}.")
        os.environ["ROOT_PATH"] = str(root_path)

    else:
        root_path = Path.cwd()
        logger.info(f"'ROOT_PATH' environment variable is not set, defaulting to working directory {root_path}.")
        os.environ["ROOT_PATH"] = str(root_path)

    return root_path


def get_pyinfra_validators():
    import pyinfra.config.validators

    return lflatten(
        validator for validator in pyinfra.config.validators.__dict__.values() if isinstance(validator, list)
    )


def validate_settings(settings: Dynaconf, validators):
    settings_valid = True

    for validator in validators:
        try:
            validator.validate(settings)
        except ValidationError as e:
            settings_valid = False
            logger.warning(e)

    if not settings_valid:
        raise ValidationError("Settings validation failed.")

    logger.debug("Settings validated.")


def parse_settings_path():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "settings_path",
        help="Path to settings file(s) or folder(s). Must be .toml file(s) or a folder(s) containing .toml files.",
        nargs="+",
    )
    return parser.parse_args().settings_path
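
Typical usage of this loader (an illustrative sketch based on the functions above; the paths come from the command line):

```python
# Illustrative: invoked as `python serve.py config/ overrides.toml`.
from pyinfra.config.loader import load_settings, parse_settings_path

settings = load_settings(parse_settings_path())  # list of .toml files / folders
```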

pyinfra/config/validators.py
@@ -1,57 +0,0 @@
from dynaconf import Validator

queue_manager_validators = [
    Validator("rabbitmq.host", must_exist=True, is_type_of=str),
    Validator("rabbitmq.port", must_exist=True, is_type_of=int),
    Validator("rabbitmq.username", must_exist=True, is_type_of=str),
    Validator("rabbitmq.password", must_exist=True, is_type_of=str),
    Validator("rabbitmq.heartbeat", must_exist=True, is_type_of=int),
    Validator("rabbitmq.connection_sleep", must_exist=True, is_type_of=int),
    Validator("rabbitmq.input_queue", must_exist=True, is_type_of=str),
    Validator("rabbitmq.output_queue", must_exist=True, is_type_of=str),
    Validator("rabbitmq.dead_letter_queue", must_exist=True, is_type_of=str),
]

azure_storage_validators = [
    Validator("storage.azure.connection_string", must_exist=True, is_type_of=str),
    Validator("storage.azure.container", must_exist=True, is_type_of=str),
]

s3_storage_validators = [
    Validator("storage.s3.endpoint", must_exist=True, is_type_of=str),
    Validator("storage.s3.key", must_exist=True, is_type_of=str),
    Validator("storage.s3.secret", must_exist=True, is_type_of=str),
    Validator("storage.s3.region", must_exist=True, is_type_of=str),
    Validator("storage.s3.bucket", must_exist=True, is_type_of=str),
]

storage_validators = [
    Validator("storage.backend", must_exist=True, is_type_of=str),
]

multi_tenant_storage_validators = [
    Validator("storage.tenant_server.endpoint", must_exist=True, is_type_of=str),
    Validator("storage.tenant_server.public_key", must_exist=True, is_type_of=str),
]


prometheus_validators = [
    Validator("metrics.prometheus.prefix", must_exist=True, is_type_of=str),
    Validator("metrics.prometheus.enabled", must_exist=True, is_type_of=bool),
]

webserver_validators = [
    Validator("webserver.host", must_exist=True, is_type_of=str),
    Validator("webserver.port", must_exist=True, is_type_of=int),
]

tracing_validators = [
    Validator("tracing.enabled", must_exist=True, is_type_of=bool),
    Validator("tracing.type", must_exist=True, is_type_of=str)
]

opentelemetry_validators = [
    Validator("tracing.opentelemetry.endpoint", must_exist=True, is_type_of=str),
    Validator("tracing.opentelemetry.service_name", must_exist=True, is_type_of=str),
    Validator("tracing.opentelemetry.exporter", must_exist=True, is_type_of=str)
]
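
A service that only uses part of the feature set can validate just the relevant groups by passing them to the loader explicitly (an illustrative sketch combining the lists above with `load_settings` from `pyinfra.config.loader`):

```python
# Illustrative: validate only queue and webserver settings for a slim service.
from pyinfra.config.loader import load_settings
from pyinfra.config.validators import queue_manager_validators, webserver_validators

settings = load_settings("config/", validators=queue_manager_validators + webserver_validators)
```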
|
||||
@ -1,169 +0,0 @@
|
||||
import asyncio
|
||||
import signal
|
||||
import sys
|
||||
|
||||
import aiohttp
|
||||
from aiormq.exceptions import AMQPConnectionError
|
||||
from dynaconf import Dynaconf
|
||||
from fastapi import FastAPI
|
||||
from kn_utils.logging import logger
|
||||
|
||||
from pyinfra.config.loader import get_pyinfra_validators, validate_settings
|
||||
from pyinfra.queue.async_manager import AsyncQueueManager, RabbitMQConfig
|
||||
from pyinfra.queue.callback import Callback
|
||||
from pyinfra.queue.manager import QueueManager
|
||||
from pyinfra.utils.opentelemetry import instrument_app, instrument_pika, setup_trace
|
||||
from pyinfra.webserver.prometheus import (
|
||||
add_prometheus_endpoint,
|
||||
make_prometheus_processing_time_decorator_from_settings,
|
||||
)
|
||||
from pyinfra.webserver.utils import (
|
||||
add_health_check_endpoint,
|
||||
create_webserver_thread_from_settings,
|
||||
run_async_webserver,
|
||||
)
|
||||
|
||||
shutdown_flag = False
|
||||
|
||||
|
||||
async def graceful_shutdown(manager: AsyncQueueManager, queue_task, webserver_task):
|
||||
global shutdown_flag
|
||||
shutdown_flag = True
|
||||
logger.info("SIGTERM received, shutting down gracefully...")
|
||||
|
||||
if queue_task and not queue_task.done():
|
||||
queue_task.cancel()
|
||||
|
||||
# await queue manager shutdown
|
||||
await asyncio.gather(queue_task, manager.shutdown(), return_exceptions=True)
|
||||
|
||||
if webserver_task and not webserver_task.done():
|
||||
webserver_task.cancel()
|
||||
|
||||
# await webserver shutdown
|
||||
await asyncio.gather(webserver_task, return_exceptions=True)
|
||||
|
||||
logger.info("Shutdown complete.")
|
||||
|
||||
|
||||
async def run_async_queues(manager: AsyncQueueManager, app, port, host):
    """Run the async webserver and the async queue manager concurrently."""
    queue_task = None
    webserver_task = None
    tenant_api_available = True

    # Add signal handlers for SIGTERM and SIGINT.
    loop = asyncio.get_running_loop()
    loop.add_signal_handler(
        signal.SIGTERM, lambda: asyncio.create_task(graceful_shutdown(manager, queue_task, webserver_task))
    )
    loop.add_signal_handler(
        signal.SIGINT, lambda: asyncio.create_task(graceful_shutdown(manager, queue_task, webserver_task))
    )

    try:
        active_tenants = await manager.fetch_active_tenants()

        queue_task = asyncio.create_task(manager.run(active_tenants=active_tenants), name="queues")
        webserver_task = asyncio.create_task(run_async_webserver(app, port, host), name="webserver")
        await asyncio.gather(queue_task, webserver_task)
    except asyncio.CancelledError:
        logger.info("Main task was cancelled, initiating shutdown.")
    except AMQPConnectionError as e:
        logger.warning(f"AMQPConnectionError: {e} - shutting down.")
    except (aiohttp.ClientResponseError, aiohttp.ClientConnectorError):
        logger.warning("Tenant server did not answer - shutting down.")
        tenant_api_available = False
    except Exception as e:
        logger.error(f"An error occurred while running async queues: {e}", exc_info=True)
        sys.exit(1)
    finally:
        if shutdown_flag:
            logger.debug("Graceful shutdown already in progress.")
        else:
            logger.warning("Initiating shutdown due to an error or manual interruption.")
            if not tenant_api_available:
                sys.exit(0)
            if queue_task and not queue_task.done():
                queue_task.cancel()

            if webserver_task and not webserver_task.done():
                webserver_task.cancel()

            await asyncio.gather(queue_task, manager.shutdown(), webserver_task, return_exceptions=True)
            logger.info("Shutdown complete.")


def start_standard_queue_consumer(
    callback: Callback,
    settings: Dynaconf,
    app: FastAPI = None,
):
    """Default serving logic for research services.

    Supplies /health, /ready and /prometheus endpoints (if enabled). The callback is monitored for processing time
    per message, and queue messages are traced via OpenTelemetry (if enabled).
    Workload is received via queue messages and processed by the callback function (see pyinfra.queue.callback for
    callbacks).
    """
    validate_settings(settings, get_pyinfra_validators())

    logger.info("Starting webserver and queue consumer...")

    app = app or FastAPI()

    if settings.metrics.prometheus.enabled:
        logger.info("Prometheus metrics enabled.")
        app = add_prometheus_endpoint(app)
        callback = make_prometheus_processing_time_decorator_from_settings(settings)(callback)

    if settings.tracing.enabled:
        setup_trace(settings)

        instrument_pika(dynamic_queues=settings.dynamic_tenant_queues.enabled)
        instrument_app(app)

    if settings.dynamic_tenant_queues.enabled:
        logger.info("Dynamic tenant queues enabled. Running async queues.")
        config = RabbitMQConfig(
            host=settings.rabbitmq.host,
            port=settings.rabbitmq.port,
            username=settings.rabbitmq.username,
            password=settings.rabbitmq.password,
            heartbeat=settings.rabbitmq.heartbeat,
            input_queue_prefix=settings.rabbitmq.service_request_queue_prefix,
            tenant_event_queue_suffix=settings.rabbitmq.tenant_event_queue_suffix,
            tenant_exchange_name=settings.rabbitmq.tenant_exchange_name,
            service_request_exchange_name=settings.rabbitmq.service_request_exchange_name,
            service_response_exchange_name=settings.rabbitmq.service_response_exchange_name,
            service_dead_letter_queue_name=settings.rabbitmq.service_dlq_name,
            queue_expiration_time=settings.rabbitmq.queue_expiration_time,
            pod_name=settings.kubernetes.pod_name,
        )
        manager = AsyncQueueManager(
            config=config,
            tenant_service_url=settings.storage.tenant_server.endpoint,
            message_processor=callback,
            max_concurrent_tasks=(
                settings.asyncio.max_concurrent_tasks if hasattr(settings.asyncio, "max_concurrent_tasks") else 10
            ),
        )
    else:
        logger.info("Dynamic tenant queues disabled. Running sync queues.")
        manager = QueueManager(settings)

    app = add_health_check_endpoint(app, manager.is_ready)

    if isinstance(manager, AsyncQueueManager):
        asyncio.run(run_async_queues(manager, app, port=settings.webserver.port, host=settings.webserver.host))

    elif isinstance(manager, QueueManager):
        webserver = create_webserver_thread_from_settings(app, settings)
        webserver.start()
        try:
            manager.start_consuming(callback)
        except Exception as e:
            logger.error(f"An error occurred while consuming messages: {e}", exc_info=True)
            sys.exit(1)
    else:
        logger.warning(f"Behavior for type {type(manager)} is not defined")
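For orientation, a minimal sketch of how a research service might wire this up. The import path for start_standard_queue_consumer and the settings file name are assumptions, and the data processor is a stand-in that just counts bytes:

from dynaconf import Dynaconf

from pyinfra.queue.callback import make_download_process_upload_callback
from pyinfra.serving import start_standard_queue_consumer  # assumed module path


def count_bytes(data, payload):
    # Stand-in data processor: must return something JSON-serializable.
    return {"file_path": data["file_path"], "n_bytes": len(data["data"])}


if __name__ == "__main__":
    settings = Dynaconf(settings_files=["settings.yaml"], environments=False)
    callback = make_download_process_upload_callback(count_bytes, settings)
    start_standard_queue_consumer(callback, settings)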
3
pyinfra/k8s_probes/__init__.py
Normal file
@ -0,0 +1,3 @@
from pyinfra.k8s_probes import startup

__all__ = ["startup"]
36
pyinfra/k8s_probes/startup.py
Normal file
@ -0,0 +1,36 @@
import logging
import sys
from pathlib import Path

from pyinfra.queue.queue_manager import token_file_name


def check_token_file():
    """
    Checks whether the token file of the QueueManager exists and is not empty, i.e. the queue manager has been started.

    NOTE: This function suppresses all exceptions.

    Returns True if the queue manager has been started, False otherwise.
    """
    try:
        token_file_path = Path(token_file_name())

        if token_file_path.exists():
            with token_file_path.open(mode="r", encoding="utf8") as token_file:
                contents = token_file.read().strip()

            return contents != ""
    # We intentionally do not handle specific exceptions here, since this is only used in a short probe script.
    # Take care to expand this if the intended use changes.
    except Exception:
        logging.getLogger(__file__).info("Caught exception when reading from token file", exc_info=True)
    return False


if __name__ == '__main__':
    if check_token_file():
        sys.exit(0)
    else:
        sys.exit(1)
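As a rough illustration of the intended use: the script above is meant to be run as a startup-probe command and judged only by its exit code. A hedged local check, using the module path added in this change:

import subprocess
import sys

# Run the probe script the same way a Kubernetes exec probe would and inspect the exit code.
result = subprocess.run([sys.executable, "-m", "pyinfra.k8s_probes.startup"])
print("queue manager started" if result.returncode == 0 else "queue manager not started yet")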
@ -0,0 +1,3 @@
from pyinfra.queue import queue_manager

__all__ = ["queue_manager"]
@ -1,329 +0,0 @@
|
||||
import asyncio
|
||||
import concurrent.futures
|
||||
import json
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Callable, Dict, Set
|
||||
|
||||
import aiohttp
|
||||
from aio_pika import ExchangeType, IncomingMessage, Message, connect
|
||||
from aio_pika.abc import (
|
||||
AbstractChannel,
|
||||
AbstractConnection,
|
||||
AbstractExchange,
|
||||
AbstractIncomingMessage,
|
||||
AbstractQueue,
|
||||
)
|
||||
from aio_pika.exceptions import (
|
||||
ChannelClosed,
|
||||
ChannelInvalidStateError,
|
||||
ConnectionClosed,
|
||||
)
|
||||
from aiormq.exceptions import AMQPConnectionError
|
||||
from kn_utils.logging import logger
|
||||
from kn_utils.retry import retry
|
||||
|
||||
|
||||
@dataclass
|
||||
class RabbitMQConfig:
|
||||
host: str
|
||||
port: int
|
||||
username: str
|
||||
password: str
|
||||
heartbeat: int
|
||||
input_queue_prefix: str
|
||||
tenant_event_queue_suffix: str
|
||||
tenant_exchange_name: str
|
||||
service_request_exchange_name: str
|
||||
service_response_exchange_name: str
|
||||
service_dead_letter_queue_name: str
|
||||
queue_expiration_time: int
|
||||
pod_name: str
|
||||
|
||||
connection_params: Dict[str, object] = field(init=False)
|
||||
|
||||
def __post_init__(self):
|
||||
self.connection_params = {
|
||||
"host": self.host,
|
||||
"port": self.port,
|
||||
"login": self.username,
|
||||
"password": self.password,
|
||||
"client_properties": {"heartbeat": self.heartbeat},
|
||||
}
|
||||
|
||||
|
||||
class AsyncQueueManager:
|
||||
def __init__(
|
||||
self,
|
||||
config: RabbitMQConfig,
|
||||
tenant_service_url: str,
|
||||
message_processor: Callable[[Dict[str, Any]], Dict[str, Any]],
|
||||
max_concurrent_tasks: int = 10,
|
||||
):
|
||||
self.config = config
|
||||
self.tenant_service_url = tenant_service_url
|
||||
self.message_processor = message_processor
|
||||
self.semaphore = asyncio.Semaphore(max_concurrent_tasks)
|
||||
|
||||
self.connection: AbstractConnection | None = None
|
||||
self.channel: AbstractChannel | None = None
|
||||
self.tenant_exchange: AbstractExchange | None = None
|
||||
self.input_exchange: AbstractExchange | None = None
|
||||
self.output_exchange: AbstractExchange | None = None
|
||||
self.tenant_exchange_queue: AbstractQueue | None = None
|
||||
self.tenant_queues: Dict[str, AbstractChannel] = {}
|
||||
self.consumer_tags: Dict[str, str] = {}
|
||||
|
||||
self.message_count: int = 0
|
||||
|
||||
@retry(tries=5, exceptions=AMQPConnectionError, reraise=True, logger=logger)
|
||||
async def connect(self) -> None:
|
||||
logger.info("Attempting to connect to RabbitMQ...")
|
||||
self.connection = await connect(**self.config.connection_params)
|
||||
self.connection.close_callbacks.add(self.on_connection_close)
|
||||
self.channel = await self.connection.channel()
|
||||
await self.channel.set_qos(prefetch_count=1)
|
||||
logger.info("Successfully connected to RabbitMQ")
|
||||
|
||||
async def on_connection_close(self, sender, exc):
|
||||
"""This is a callback for unexpected connection closures."""
|
||||
logger.debug(f"Sender: {sender}")
|
||||
if isinstance(exc, ConnectionClosed):
|
||||
logger.warning("Connection to RabbitMQ lost. Attempting to reconnect...")
|
||||
try:
|
||||
active_tenants = await self.fetch_active_tenants()
|
||||
await self.run(active_tenants=active_tenants)
|
||||
logger.debug("Reconnected to RabbitMQ successfully")
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to reconnect to RabbitMQ: {e}")
|
||||
# Cancel the queue manager and webserver tasks to shut down the service.
tasks = [t for t in asyncio.all_tasks() if t is not asyncio.current_task()]
for task in tasks:
    if task.get_name() in ["queues", "webserver"]:
        task.cancel()
|
||||
else:
|
||||
logger.debug("Connection closed on purpose.")
|
||||
|
||||
async def is_ready(self) -> bool:
|
||||
if self.connection is None or self.connection.is_closed:
|
||||
try:
|
||||
await self.connect()
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to connect to RabbitMQ: {e}")
|
||||
return False
|
||||
return True
|
||||
|
||||
@retry(tries=5, exceptions=(AMQPConnectionError, ChannelInvalidStateError), reraise=True, logger=logger)
|
||||
async def setup_exchanges(self) -> None:
|
||||
self.tenant_exchange = await self.channel.declare_exchange(
|
||||
self.config.tenant_exchange_name, ExchangeType.TOPIC, durable=True
|
||||
)
|
||||
self.input_exchange = await self.channel.declare_exchange(
|
||||
self.config.service_request_exchange_name, ExchangeType.DIRECT, durable=True
|
||||
)
|
||||
self.output_exchange = await self.channel.declare_exchange(
|
||||
self.config.service_response_exchange_name, ExchangeType.DIRECT, durable=True
|
||||
)
|
||||
|
||||
# We must declare the dead-letter queue so failed messages can be routed to it.
|
||||
self.dead_letter_queue = await self.channel.declare_queue(
|
||||
self.config.service_dead_letter_queue_name, durable=True
|
||||
)
|
||||
|
||||
@retry(tries=5, exceptions=(AMQPConnectionError, ChannelInvalidStateError), reraise=True, logger=logger)
|
||||
async def setup_tenant_queue(self) -> None:
|
||||
self.tenant_exchange_queue = await self.channel.declare_queue(
|
||||
f"{self.config.pod_name}_{self.config.tenant_event_queue_suffix}",
|
||||
durable=True,
|
||||
arguments={
|
||||
"x-dead-letter-exchange": "",
|
||||
"x-dead-letter-routing-key": self.config.service_dead_letter_queue_name,
|
||||
"x-expires": self.config.queue_expiration_time,
|
||||
},
|
||||
)
|
||||
await self.tenant_exchange_queue.bind(self.tenant_exchange, routing_key="tenant.*")
|
||||
self.consumer_tags["tenant_exchange_queue"] = await self.tenant_exchange_queue.consume(
|
||||
self.process_tenant_message
|
||||
)
|
||||
|
||||
async def process_tenant_message(self, message: AbstractIncomingMessage) -> None:
|
||||
try:
|
||||
async with message.process():
|
||||
message_body = json.loads(message.body.decode())
|
||||
logger.debug(f"Tenant message received: {message_body}")
|
||||
tenant_id = message_body["tenantId"]
|
||||
routing_key = message.routing_key
|
||||
|
||||
if routing_key == "tenant.created":
|
||||
await self.create_tenant_queues(tenant_id)
|
||||
elif routing_key == "tenant.delete":
|
||||
await self.delete_tenant_queues(tenant_id)
|
||||
except Exception as e:
|
||||
logger.error(e, exc_info=True)
|
||||
|
||||
async def create_tenant_queues(self, tenant_id: str) -> None:
|
||||
queue_name = f"{self.config.input_queue_prefix}_{tenant_id}"
|
||||
logger.info(f"Declaring queue: {queue_name}")
|
||||
try:
|
||||
input_queue = await self.channel.declare_queue(
|
||||
queue_name,
|
||||
durable=True,
|
||||
arguments={
|
||||
"x-dead-letter-exchange": "",
|
||||
"x-dead-letter-routing-key": self.config.service_dead_letter_queue_name,
|
||||
},
|
||||
)
|
||||
await input_queue.bind(self.input_exchange, routing_key=tenant_id)
|
||||
self.consumer_tags[tenant_id] = await input_queue.consume(self.process_input_message)
|
||||
self.tenant_queues[tenant_id] = input_queue
|
||||
logger.info(f"Created and started consuming queue for tenant {tenant_id}")
|
||||
except Exception as e:
|
||||
logger.error(e, exc_info=True)
|
||||
|
||||
async def delete_tenant_queues(self, tenant_id: str) -> None:
|
||||
if tenant_id in self.tenant_queues:
|
||||
# somehow queue.delete() does not work here
|
||||
await self.channel.queue_delete(f"{self.config.input_queue_prefix}_{tenant_id}")
|
||||
del self.tenant_queues[tenant_id]
|
||||
del self.consumer_tags[tenant_id]
|
||||
logger.info(f"Deleted queues for tenant {tenant_id}")
|
||||
|
||||
async def process_input_message(self, message: IncomingMessage) -> None:
|
||||
async def process_message_body_and_await_result(unpacked_message_body):
|
||||
async with self.semaphore:
|
||||
loop = asyncio.get_running_loop()
|
||||
with concurrent.futures.ThreadPoolExecutor(max_workers=1) as thread_pool_executor:
|
||||
logger.info("Processing payload in a separate thread.")
|
||||
result = await loop.run_in_executor(
|
||||
thread_pool_executor, self.message_processor, unpacked_message_body
|
||||
)
|
||||
return result
|
||||
|
||||
async with message.process(ignore_processed=True):
|
||||
if message.redelivered:
|
||||
logger.warning(f"Declining message with {message.delivery_tag=} due to it being redelivered.")
|
||||
await message.nack(requeue=False)
|
||||
return
|
||||
|
||||
if message.body.decode("utf-8") == "STOP":
|
||||
logger.info("Received stop signal, stopping consumption...")
|
||||
await message.ack()
|
||||
# TODO: shutdown is probably not the right call here - align w/ Dev what should happen on stop signal
|
||||
await self.shutdown()
|
||||
return
|
||||
|
||||
self.message_count += 1
|
||||
|
||||
try:
|
||||
tenant_id = message.routing_key
|
||||
|
||||
filtered_message_headers = (
|
||||
{k: v for k, v in message.headers.items() if k.lower().startswith("x-")} if message.headers else {}
|
||||
)
|
||||
|
||||
logger.debug(f"Processing message with {filtered_message_headers=}.")
|
||||
|
||||
# Fall back to {} only if the awaited processing result itself is falsy.
result: dict = (
    await process_message_body_and_await_result({**json.loads(message.body), **filtered_message_headers})
) or {}
|
||||
|
||||
if result:
|
||||
await self.publish_to_output_exchange(tenant_id, result, filtered_message_headers)
|
||||
await message.ack()
|
||||
logger.debug(f"Message with {message.delivery_tag=} acknowledged.")
|
||||
else:
|
||||
raise ValueError(f"Could not process message with {message.body=}.")
|
||||
|
||||
except json.JSONDecodeError:
|
||||
await message.nack(requeue=False)
|
||||
logger.error(f"Invalid JSON in input message: {message.body}", exc_info=True)
|
||||
except FileNotFoundError as e:
|
||||
logger.warning(f"{e}, declining message with {message.delivery_tag=}.", exc_info=True)
|
||||
await message.nack(requeue=False)
|
||||
except Exception as e:
|
||||
await message.nack(requeue=False)
|
||||
logger.error(f"Error processing input message: {e}", exc_info=True)
|
||||
finally:
|
||||
self.message_count -= 1
|
||||
|
||||
async def publish_to_output_exchange(self, tenant_id: str, result: Dict[str, Any], headers: Dict[str, Any]) -> None:
|
||||
await self.output_exchange.publish(
|
||||
Message(body=json.dumps(result).encode(), headers=headers),
|
||||
routing_key=tenant_id,
|
||||
)
|
||||
logger.info(f"Published result to queue {tenant_id}.")
|
||||
|
||||
@retry(tries=5, exceptions=(aiohttp.ClientResponseError, aiohttp.ClientConnectorError), reraise=True, logger=logger)
|
||||
async def fetch_active_tenants(self) -> Set[str]:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(self.tenant_service_url) as response:
|
||||
response.raise_for_status()
|
||||
if response.headers["content-type"].lower() == "application/json":
|
||||
data = await response.json()
|
||||
return {tenant["tenantId"] for tenant in data}
|
||||
else:
|
||||
logger.error(
|
||||
f"Failed to fetch active tenants. Content type is not JSON: {response.headers['content-type'].lower()}"
|
||||
)
|
||||
return set()
|
||||
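The tenant endpoint is assumed to return a JSON array of objects carrying a tenantId field; a small illustration of the parsing above, with invented values:

import json

body = '[{"tenantId": "tenant-a"}, {"tenantId": "tenant-b"}]'  # hypothetical response body
active_tenants = {tenant["tenantId"] for tenant in json.loads(body)}
assert active_tenants == {"tenant-a", "tenant-b"}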
|
||||
@retry(
|
||||
tries=5,
|
||||
exceptions=(
|
||||
AMQPConnectionError,
|
||||
ChannelInvalidStateError,
|
||||
),
|
||||
reraise=True,
|
||||
logger=logger,
|
||||
)
|
||||
async def initialize_tenant_queues(self, active_tenants: set) -> None:
|
||||
for tenant_id in active_tenants:
|
||||
await self.create_tenant_queues(tenant_id)
|
||||
|
||||
async def run(self, active_tenants: set) -> None:
|
||||
|
||||
await self.connect()
|
||||
await self.setup_exchanges()
|
||||
await self.initialize_tenant_queues(active_tenants=active_tenants)
|
||||
await self.setup_tenant_queue()
|
||||
|
||||
logger.info("RabbitMQ handler is running. Press CTRL+C to exit.")
|
||||
|
||||
async def close_channels(self) -> None:
|
||||
try:
|
||||
if self.channel and not self.channel.is_closed:
|
||||
# Cancel queues to stop fetching messages
|
||||
logger.debug("Cancelling queues...")
|
||||
for tenant, queue in self.tenant_queues.items():
|
||||
await queue.cancel(self.consumer_tags[tenant])
|
||||
if self.tenant_exchange_queue:
|
||||
await self.tenant_exchange_queue.cancel(self.consumer_tags["tenant_exchange_queue"])
|
||||
while self.message_count != 0:
|
||||
logger.debug(f"Messages are still being processed: {self.message_count=} ")
|
||||
await asyncio.sleep(2)
|
||||
await self.channel.close(exc=asyncio.CancelledError)
|
||||
logger.debug("Channel closed.")
|
||||
else:
|
||||
logger.debug("No channel to close.")
|
||||
except ChannelClosed:
|
||||
logger.warning("Channel was already closed.")
|
||||
except ConnectionClosed:
|
||||
logger.warning("Connection was lost, unable to close channel.")
|
||||
except Exception as e:
|
||||
logger.error(f"Error during channel shutdown: {e}")
|
||||
|
||||
async def close_connection(self) -> None:
|
||||
try:
|
||||
if self.connection and not self.connection.is_closed:
|
||||
await self.connection.close(exc=asyncio.CancelledError)
|
||||
logger.debug("Connection closed.")
|
||||
else:
|
||||
logger.debug("No connection to close.")
|
||||
except ConnectionClosed:
|
||||
logger.warning("Connection was already closed.")
|
||||
except Exception as e:
|
||||
logger.error(f"Error closing connection: {e}")
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
logger.info("Shutting down RabbitMQ handler...")
|
||||
await self.close_channels()
|
||||
await self.close_connection()
|
||||
logger.info("RabbitMQ handler shut down successfully.")
|
||||
@ -1,42 +0,0 @@
from typing import Callable

from dynaconf import Dynaconf
from kn_utils.logging import logger

from pyinfra.storage.connection import get_storage
from pyinfra.storage.utils import (
    download_data_bytes_as_specified_in_message,
    upload_data_as_specified_in_message,
    DownloadedData,
)

DataProcessor = Callable[[dict[str, DownloadedData] | DownloadedData, dict], dict | list | str]
Callback = Callable[[dict], dict]


def make_download_process_upload_callback(data_processor: DataProcessor, settings: Dynaconf) -> Callback:
    """Default callback for processing queue messages.

    Data will be downloaded from the storage as specified in the message. If a tenant id is specified, the storage
    will be configured to use that tenant id; otherwise the storage is configured as specified in the settings.
    The data is then passed to the data processor, together with the message. The data processor should return a
    JSON-serializable object. This object is then uploaded to the storage as specified in the message. The response
    message is just the original message.
    """

    def inner(queue_message_payload: dict) -> dict:
        logger.info("Processing payload with download-process-upload callback...")

        storage = get_storage(settings, queue_message_payload.get("X-TENANT-ID"))

        data: dict[str, DownloadedData] | DownloadedData = download_data_bytes_as_specified_in_message(
            storage, queue_message_payload
        )

        result = data_processor(data, queue_message_payload)

        upload_data_as_specified_in_message(storage, queue_message_payload, result)

        return queue_message_payload

    return inner
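To make the DataProcessor contract concrete, a hedged sketch of a processor that tolerates both the single-file and the multi-file shapes produced by download_data_bytes_as_specified_in_message; the names and returned fields are illustrative:

from pyinfra.storage.utils import DownloadedData


def summarize(data: dict[str, DownloadedData] | DownloadedData, payload: dict) -> dict:
    # A single download target yields one DownloadedData dict; several targets yield a mapping of key -> DownloadedData.
    if "data" in data and "file_path" in data:
        return {"file_path": data["file_path"], "n_bytes": len(data["data"])}
    return {key: {"file_path": item["file_path"], "n_bytes": len(item["data"])} for key, item in data.items()}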
@ -1,229 +0,0 @@
|
||||
import atexit
|
||||
import concurrent.futures
|
||||
import json
|
||||
import logging
|
||||
import signal
|
||||
import sys
|
||||
from typing import Callable, Union
|
||||
|
||||
import pika
|
||||
import pika.exceptions
|
||||
from dynaconf import Dynaconf
|
||||
from kn_utils.logging import logger
|
||||
from kn_utils.retry import retry
|
||||
from pika.adapters.blocking_connection import BlockingChannel, BlockingConnection
|
||||
|
||||
from pyinfra.config.loader import validate_settings
|
||||
from pyinfra.config.validators import queue_manager_validators
|
||||
|
||||
pika_logger = logging.getLogger("pika")
|
||||
pika_logger.setLevel(logging.WARNING) # disables non-informative pika log clutter
|
||||
|
||||
MessageProcessor = Callable[[dict], dict]
|
||||
|
||||
|
||||
class QueueManager:
|
||||
def __init__(self, settings: Dynaconf):
|
||||
validate_settings(settings, queue_manager_validators)
|
||||
|
||||
self.input_queue = settings.rabbitmq.input_queue
|
||||
self.output_queue = settings.rabbitmq.output_queue
|
||||
self.dead_letter_queue = settings.rabbitmq.dead_letter_queue
|
||||
|
||||
self.connection_parameters = self.create_connection_parameters(settings)
|
||||
|
||||
self.connection: Union[BlockingConnection, None] = None
|
||||
self.channel: Union[BlockingChannel, None] = None
|
||||
self.connection_sleep = settings.rabbitmq.connection_sleep
|
||||
self.processing_callback = False
|
||||
self.received_signal = False
|
||||
|
||||
atexit.register(self.stop_consuming)
|
||||
signal.signal(signal.SIGTERM, self._handle_stop_signal)
|
||||
signal.signal(signal.SIGINT, self._handle_stop_signal)
|
||||
|
||||
self.max_retries = settings.rabbitmq.max_retries or 5
|
||||
self.max_delay = settings.rabbitmq.max_delay or 60
|
||||
|
||||
@staticmethod
|
||||
def create_connection_parameters(settings: Dynaconf):
|
||||
credentials = pika.PlainCredentials(username=settings.rabbitmq.username, password=settings.rabbitmq.password)
|
||||
pika_connection_params = {
|
||||
"host": settings.rabbitmq.host,
|
||||
"port": settings.rabbitmq.port,
|
||||
"credentials": credentials,
|
||||
"heartbeat": settings.rabbitmq.heartbeat,
|
||||
}
|
||||
|
||||
return pika.ConnectionParameters(**pika_connection_params)
|
||||
|
||||
@retry(
|
||||
tries=5,
|
||||
exceptions=(pika.exceptions.AMQPConnectionError, pika.exceptions.ChannelClosedByBroker),
|
||||
reraise=True,
|
||||
)
|
||||
def establish_connection(self):
|
||||
if self.connection and self.connection.is_open:
|
||||
logger.debug("Connection to RabbitMQ already established.")
|
||||
return
|
||||
|
||||
logger.info("Establishing connection to RabbitMQ...")
|
||||
self.connection = pika.BlockingConnection(parameters=self.connection_parameters)
|
||||
|
||||
logger.debug("Opening channel...")
|
||||
self.channel = self.connection.channel()
|
||||
self.channel.basic_qos(prefetch_count=1)
|
||||
|
||||
args = {
|
||||
"x-dead-letter-exchange": "",
|
||||
"x-dead-letter-routing-key": self.dead_letter_queue,
|
||||
}
|
||||
|
||||
self.channel.queue_declare(self.input_queue, arguments=args, auto_delete=False, durable=True)
|
||||
self.channel.queue_declare(self.output_queue, arguments=args, auto_delete=False, durable=True)
|
||||
|
||||
logger.info("Connection to RabbitMQ established, channel open.")
|
||||
|
||||
def is_ready(self):
|
||||
try:
|
||||
self.establish_connection()
|
||||
return self.channel.is_open
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to establish connection: {e}")
|
||||
return False
|
||||
|
||||
@retry(
|
||||
tries=5,
|
||||
exceptions=pika.exceptions.AMQPConnectionError,
|
||||
reraise=True,
|
||||
)
|
||||
def start_consuming(self, message_processor: Callable):
|
||||
on_message_callback = self._make_on_message_callback(message_processor)
|
||||
|
||||
try:
|
||||
self.establish_connection()
|
||||
self.channel.basic_consume(self.input_queue, on_message_callback)
|
||||
logger.info("Starting to consume messages...")
|
||||
self.channel.start_consuming()
|
||||
except pika.exceptions.AMQPConnectionError as e:
|
||||
logger.error(f"AMQP Connection Error: {e}")
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"An unexpected error occurred while consuming messages: {e}", exc_info=True)
|
||||
raise
|
||||
finally:
|
||||
self.stop_consuming()
|
||||
|
||||
def stop_consuming(self):
|
||||
if self.channel and self.channel.is_open:
|
||||
logger.info("Stopping consuming...")
|
||||
self.channel.stop_consuming()
|
||||
logger.info("Closing channel...")
|
||||
self.channel.close()
|
||||
|
||||
if self.connection and self.connection.is_open:
|
||||
logger.info("Closing connection to RabbitMQ...")
|
||||
self.connection.close()
|
||||
|
||||
def publish_message_to_input_queue(self, message: Union[str, bytes, dict], properties: pika.BasicProperties = None):
|
||||
if isinstance(message, str):
|
||||
message = message.encode("utf-8")
|
||||
elif isinstance(message, dict):
|
||||
message = json.dumps(message).encode("utf-8")
|
||||
|
||||
self.establish_connection()
|
||||
self.channel.basic_publish(
|
||||
"",
|
||||
self.input_queue,
|
||||
properties=properties,
|
||||
body=message,
|
||||
)
|
||||
logger.info(f"Published message to queue {self.input_queue}.")
|
||||
|
||||
def purge_queues(self):
|
||||
self.establish_connection()
|
||||
try:
|
||||
self.channel.queue_purge(self.input_queue)
|
||||
self.channel.queue_purge(self.output_queue)
|
||||
logger.info("Queues purged.")
|
||||
except pika.exceptions.ChannelWrongStateError:
|
||||
pass
|
||||
|
||||
def get_message_from_output_queue(self):
|
||||
self.establish_connection()
|
||||
return self.channel.basic_get(self.output_queue, auto_ack=True)
|
||||
|
||||
def _make_on_message_callback(self, message_processor: MessageProcessor):
|
||||
def process_message_body_and_await_result(unpacked_message_body):
|
||||
# Processing the message in a separate thread is necessary for the main thread pika client to be able to
|
||||
# process data events (e.g. heartbeats) while the message is being processed.
|
||||
with concurrent.futures.ThreadPoolExecutor(max_workers=1) as thread_pool_executor:
|
||||
logger.info("Processing payload in separate thread.")
|
||||
future = thread_pool_executor.submit(message_processor, unpacked_message_body)
|
||||
|
||||
# TODO: This block is probably not necessary, but kept since the implications of removing it are
|
||||
# unclear. Remove it in a future iteration where less changes are being made to the code base.
|
||||
while future.running():
|
||||
logger.debug("Waiting for payload processing to finish...")
|
||||
self.connection.sleep(self.connection_sleep)
|
||||
|
||||
return future.result()
|
||||
|
||||
def on_message_callback(channel, method, properties, body):
|
||||
logger.info(f"Received message from queue with delivery_tag {method.delivery_tag}.")
|
||||
self.processing_callback = True
|
||||
|
||||
if method.redelivered:
|
||||
logger.warning(f"Declining message with {method.delivery_tag=} due to it being redelivered.")
|
||||
channel.basic_nack(method.delivery_tag, requeue=False)
|
||||
return
|
||||
|
||||
if body.decode("utf-8") == "STOP":
|
||||
logger.info(f"Received stop signal, stopping consuming...")
|
||||
channel.basic_ack(delivery_tag=method.delivery_tag)
|
||||
self.stop_consuming()
|
||||
return
|
||||
|
||||
try:
|
||||
filtered_message_headers = (
|
||||
{k: v for k, v in properties.headers.items() if k.lower().startswith("x-")}
|
||||
if properties.headers
|
||||
else {}
|
||||
)
|
||||
logger.debug(f"Processing message with {filtered_message_headers=}.")
|
||||
result: dict = (
|
||||
process_message_body_and_await_result({**json.loads(body), **filtered_message_headers}) or {}
|
||||
)
|
||||
|
||||
channel.basic_publish(
|
||||
"",
|
||||
self.output_queue,
|
||||
json.dumps(result).encode(),
|
||||
properties=pika.BasicProperties(headers=filtered_message_headers),
|
||||
)
|
||||
logger.info(f"Published result to queue {self.output_queue}.")
|
||||
|
||||
channel.basic_ack(delivery_tag=method.delivery_tag)
|
||||
logger.debug(f"Message with {method.delivery_tag=} acknowledged.")
|
||||
except FileNotFoundError as e:
|
||||
logger.warning(f"{e}, declining message with {method.delivery_tag=}.")
|
||||
channel.basic_nack(method.delivery_tag, requeue=False)
|
||||
except Exception:
|
||||
logger.warning(f"Failed to process message with {method.delivery_tag=}, declining...", exc_info=True)
|
||||
channel.basic_nack(method.delivery_tag, requeue=False)
|
||||
raise
|
||||
|
||||
finally:
|
||||
self.processing_callback = False
|
||||
if self.received_signal:
|
||||
self.stop_consuming()
|
||||
sys.exit(0)
|
||||
|
||||
return on_message_callback
|
||||
|
||||
def _handle_stop_signal(self, signum, *args, **kwargs):
|
||||
logger.info(f"Received signal {signum}, stopping consuming...")
|
||||
self.received_signal = True
|
||||
if not self.processing_callback:
|
||||
self.stop_consuming()
|
||||
sys.exit(0)
|
||||
174
pyinfra/queue/queue_manager.py
Normal file
@ -0,0 +1,174 @@
|
||||
import atexit
|
||||
import json
|
||||
import logging
|
||||
import signal
|
||||
from typing import Callable
|
||||
from pathlib import Path
|
||||
import concurrent.futures
|
||||
|
||||
import pika
|
||||
import pika.exceptions
|
||||
|
||||
from pyinfra.config import Config
|
||||
|
||||
pika_logger = logging.getLogger("pika")
|
||||
pika_logger.setLevel(logging.WARNING)
|
||||
|
||||
|
||||
def get_connection_params(config: Config) -> pika.ConnectionParameters:
|
||||
credentials = pika.PlainCredentials(username=config.rabbitmq_username, password=config.rabbitmq_password)
|
||||
pika_connection_params = {
|
||||
"host": config.rabbitmq_host,
|
||||
"port": config.rabbitmq_port,
|
||||
"credentials": credentials,
|
||||
"heartbeat": int(config.rabbitmq_heartbeat),
|
||||
}
|
||||
|
||||
return pika.ConnectionParameters(**pika_connection_params)
|
||||
|
||||
|
||||
def _get_n_previous_attempts(props):
|
||||
return 0 if props.headers is None else props.headers.get("x-retry-count", 0)
|
||||
|
||||
|
||||
def token_file_name():
|
||||
token_file_path = Path("/tmp") / "consumer_token.txt"
|
||||
return token_file_path
|
||||
|
||||
|
||||
class QueueManager(object):
|
||||
def __init__(self, config: Config):
|
||||
self.logger = logging.getLogger("queue_manager")
|
||||
self.logger.setLevel(config.logging_level_root)
|
||||
|
||||
self._write_token = config.write_consumer_token == "True"
|
||||
|
||||
self._set_consumer_token(None)
|
||||
|
||||
self._connection_params = get_connection_params(config)
|
||||
|
||||
# Controls how long we process only data events (e.g. heartbeats) while the channel is
# blocked and the given callback function runs in a worker thread.
self._connection_sleep = config.rabbitmq_connection_sleep
|
||||
|
||||
self._input_queue = config.request_queue
|
||||
self._output_queue = config.response_queue
|
||||
self._dead_letter_queue = config.dead_letter_queue
|
||||
|
||||
atexit.register(self.stop_consuming)
|
||||
signal.signal(signal.SIGTERM, self._handle_stop_signal)
|
||||
signal.signal(signal.SIGINT, self._handle_stop_signal)
|
||||
|
||||
def _set_consumer_token(self, token_value):
|
||||
self._consumer_token = token_value
|
||||
|
||||
if self._write_token:
|
||||
token_file_path = token_file_name()
|
||||
|
||||
with token_file_path.open(mode="w", encoding="utf8") as token_file:
|
||||
text = token_value if token_value is not None else ""
|
||||
token_file.write(text)
|
||||
|
||||
def _open_channel(self):
|
||||
self._connection = pika.BlockingConnection(parameters=self._connection_params)
|
||||
self._channel = self._connection.channel()
|
||||
self._channel.basic_qos(prefetch_count=1)
|
||||
|
||||
args = {"x-dead-letter-exchange": "", "x-dead-letter-routing-key": self._dead_letter_queue}
|
||||
|
||||
self._channel.queue_declare(self._input_queue, arguments=args, auto_delete=False, durable=True)
|
||||
self._channel.queue_declare(self._output_queue, arguments=args, auto_delete=False, durable=True)
|
||||
|
||||
def _close_channel(self):
|
||||
self._channel.close()
|
||||
self._connection.close()
|
||||
|
||||
def start_consuming(self, process_message_callback: Callable):
|
||||
callback = self._create_queue_callback(process_message_callback)
|
||||
|
||||
self._set_consumer_token(None)
|
||||
|
||||
self.logger.info("Consuming from queue")
|
||||
try:
|
||||
self._open_channel()
|
||||
|
||||
self._set_consumer_token(self._channel.basic_consume(self._input_queue, callback))
|
||||
self.logger.info(f"Registered with consumer-tag: {self._consumer_token}")
|
||||
self._channel.start_consuming()
|
||||
except Exception:
|
||||
self.logger.warning(
|
||||
"An unexpected exception occurred while consuming messages. Consuming will stop."
|
||||
)
|
||||
raise
|
||||
finally:
|
||||
self.stop_consuming()
|
||||
self._close_channel()
|
||||
|
||||
def stop_consuming(self):
|
||||
if self._consumer_token and self._connection:
|
||||
self.logger.info(f"Cancelling subscription for consumer-tag: {self._consumer_token}")
|
||||
self._channel.stop_consuming(self._consumer_token)
|
||||
self._set_consumer_token(None)
|
||||
|
||||
def _handle_stop_signal(self, signal_number, _stack_frame, *args, **kwargs):
|
||||
self.logger.info(f"Received signal {signal_number}")
|
||||
self.stop_consuming()
|
||||
|
||||
def _create_queue_callback(self, process_message_callback: Callable):
|
||||
def process_message_body_and_await_result(unpacked_message_body):
|
||||
with concurrent.futures.ThreadPoolExecutor(max_workers=1) as thread_pool_executor:
|
||||
self.logger.debug("opening thread for callback")
|
||||
future = thread_pool_executor.submit(process_message_callback, unpacked_message_body)
|
||||
|
||||
while future.running():
|
||||
self.logger.debug("callback running in thread, processing data events in the meantime")
|
||||
self._connection.sleep(float(self._connection_sleep))
|
||||
|
||||
self.logger.debug("fetching result from callback")
|
||||
return future.result()
|
||||
|
||||
def callback(_channel, frame, properties, body):
|
||||
self.logger.info(f"Received message from queue with delivery_tag {frame.delivery_tag}")
|
||||
|
||||
# Only try to process each message once.
|
||||
# Requeueing will be handled by the dead-letter-exchange.
|
||||
# This prevents endless retries on messages that are impossible to process.
|
||||
if frame.redelivered:
|
||||
self.logger.info(f"Aborting message processing for delivery_tag {frame.delivery_tag} "
|
||||
f"due to it being redelivered")
|
||||
self._channel.basic_nack(frame.delivery_tag, requeue=False)
|
||||
return
|
||||
|
||||
self.logger.debug(f"Processing {(frame, properties, body)}.")
|
||||
try:
|
||||
unpacked_message_body = json.loads(body)
|
||||
|
||||
should_publish_result, callback_result = process_message_body_and_await_result(unpacked_message_body)
|
||||
|
||||
if should_publish_result:
|
||||
self.logger.info(f"Processed message with delivery_tag {frame.delivery_tag}, "
|
||||
f"publishing result to result-queue")
|
||||
self._channel.basic_publish("", self._output_queue, json.dumps(callback_result).encode())
|
||||
|
||||
self.logger.info(
|
||||
f"Result published, acknowledging incoming message with delivery_tag {frame.delivery_tag}"
|
||||
)
|
||||
self._channel.basic_ack(frame.delivery_tag)
|
||||
else:
|
||||
self.logger.info(f"Processed message with delivery_tag {frame.delivery_tag}, declining message")
|
||||
self._channel.basic_nack(frame.delivery_tag, requeue=False)
|
||||
|
||||
except Exception as ex:
|
||||
n_attempts = _get_n_previous_attempts(properties) + 1
|
||||
self.logger.warning(f"Failed to process message, {n_attempts} attempts, error: {str(ex)}")
|
||||
self._channel.basic_nack(frame.delivery_tag, requeue=False)
|
||||
raise ex
|
||||
|
||||
return callback
|
||||
|
||||
def clear(self):
|
||||
try:
|
||||
self._channel.queue_purge(self._input_queue)
|
||||
self._channel.queue_purge(self._output_queue)
|
||||
except pika.exceptions.ChannelWrongStateError:
|
||||
pass
|
||||
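Note that this QueueManager expects its callback to return a (should_publish_result, callback_result) pair (see _create_queue_callback above), unlike the Dynaconf-based manager earlier in this diff, which returns a plain dict. A hedged sketch of a compatible callback; the payload fields are invented:

def process_message_callback(message_body: dict):
    # Return (should_publish_result, callback_result): publish only when processing produced a result.
    if "fileId" not in message_body:
        return False, None  # decline the message; the dead-letter queue takes over
    return True, {"fileId": message_body["fileId"], "status": "processed"}

# Typical wiring (config construction not shown):
#   QueueManager(config).start_consuming(process_message_callback)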
@ -0,0 +1,4 @@
|
||||
from pyinfra.storage import adapters, storage
|
||||
from pyinfra.storage.storage import get_storage
|
||||
|
||||
__all__ = ["adapters", "storage"]
|
||||
80
pyinfra/storage/adapters/azure.py
Normal file
@ -0,0 +1,80 @@
|
||||
import logging
|
||||
from itertools import repeat
|
||||
from operator import attrgetter
|
||||
|
||||
from azure.storage.blob import ContainerClient, BlobServiceClient
|
||||
from retry import retry
|
||||
|
||||
from pyinfra.config import Config, get_config
|
||||
|
||||
CONFIG = get_config()
logger = logging.getLogger(__name__)
logger.setLevel(CONFIG.logging_level_root)
|
||||
|
||||
logging.getLogger("azure").setLevel(logging.WARNING)
|
||||
logging.getLogger("urllib3").setLevel(logging.WARNING)
|
||||
|
||||
|
||||
class AzureStorageAdapter(object):
|
||||
def __init__(self, client: BlobServiceClient):
|
||||
self._client: BlobServiceClient = client
|
||||
|
||||
def has_bucket(self, bucket_name):
|
||||
container_client = self._client.get_container_client(bucket_name)
|
||||
return container_client.exists()
|
||||
|
||||
def make_bucket(self, bucket_name):
|
||||
container_client = self._client.get_container_client(bucket_name)
|
||||
if not container_client.exists():
    self._client.create_container(bucket_name)
|
||||
|
||||
def __provide_container_client(self, bucket_name) -> ContainerClient:
|
||||
self.make_bucket(bucket_name)
|
||||
container_client = self._client.get_container_client(bucket_name)
|
||||
return container_client
|
||||
|
||||
def put_object(self, bucket_name, object_name, data):
|
||||
logger.debug(f"Uploading '{object_name}'...")
|
||||
container_client = self.__provide_container_client(bucket_name)
|
||||
blob_client = container_client.get_blob_client(object_name)
|
||||
blob_client.upload_blob(data, overwrite=True)
|
||||
|
||||
def exists(self, bucket_name, object_name):
|
||||
container_client = self.__provide_container_client(bucket_name)
|
||||
blob_client = container_client.get_blob_client(object_name)
|
||||
return blob_client.exists()
|
||||
|
||||
@retry(tries=3, delay=5, jitter=(1, 3))
|
||||
def get_object(self, bucket_name, object_name):
|
||||
logger.debug(f"Downloading '{object_name}'...")
|
||||
|
||||
try:
|
||||
container_client = self.__provide_container_client(bucket_name)
|
||||
blob_client = container_client.get_blob_client(object_name)
|
||||
blob_data = blob_client.download_blob()
|
||||
return blob_data.readall()
|
||||
except Exception as err:
|
||||
raise Exception("Failed getting object from azure client") from err
|
||||
|
||||
def get_all_objects(self, bucket_name):
|
||||
container_client = self.__provide_container_client(bucket_name)
|
||||
blobs = container_client.list_blobs()
|
||||
for blob in blobs:
|
||||
logger.debug(f"Downloading '{blob.name}'...")
|
||||
blob_client = container_client.get_blob_client(blob)
|
||||
blob_data = blob_client.download_blob()
|
||||
data = blob_data.readall()
|
||||
yield data
|
||||
|
||||
def clear_bucket(self, bucket_name):
|
||||
logger.debug(f"Clearing Azure container '{bucket_name}'...")
|
||||
container_client = self._client.get_container_client(bucket_name)
|
||||
blobs = container_client.list_blobs()
|
||||
container_client.delete_blobs(*blobs)
|
||||
|
||||
def get_all_object_names(self, bucket_name):
|
||||
container_client = self.__provide_container_client(bucket_name)
|
||||
blobs = container_client.list_blobs()
|
||||
return zip(repeat(bucket_name), map(attrgetter("name"), blobs))
|
||||
|
||||
|
||||
def get_azure_storage(config: Config):
|
||||
return AzureStorageAdapter(BlobServiceClient.from_connection_string(conn_str=config.storage_azureconnectionstring))
|
||||
100
pyinfra/storage/adapters/s3.py
Normal file
@ -0,0 +1,100 @@
|
||||
import io
|
||||
import logging
|
||||
import re
|
||||
from itertools import repeat
|
||||
from operator import attrgetter
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from minio import Minio
|
||||
from retry import retry
|
||||
|
||||
from pyinfra.config import Config, get_config
|
||||
|
||||
|
||||
CONFIG = get_config()
logger = logging.getLogger(__name__)
logger.setLevel(CONFIG.logging_level_root)
|
||||
|
||||
ALLOWED_CONNECTION_SCHEMES = {"http", "https"}
|
||||
URL_VALIDATOR = re.compile(
|
||||
r"^(("
|
||||
+ r"([A-Za-z]{3,9}:(?:\/\/)?)"
|
||||
+ r"(?:[\-;:&=\+\$,\w]+@)?"
|
||||
+ r"[A-Za-z0-9\.\-]+|(?:www\.|[\-;:&=\+\$,\w]+@)"
|
||||
+ r"[A-Za-z0-9\.\-]+)"
|
||||
+ r"((?:\/[\+~%\/\.\w\-_]*)?"
|
||||
+ r"\??(?:[\-\+=&;%@\.\w_]*)#?(?:[\.\!\/\\\w]*))?)"
|
||||
)
|
||||
|
||||
|
||||
class S3StorageAdapter(object):
|
||||
def __init__(self, client: Minio):
|
||||
self._client = client
|
||||
|
||||
def make_bucket(self, bucket_name):
|
||||
if not self.has_bucket(bucket_name):
|
||||
self._client.make_bucket(bucket_name)
|
||||
|
||||
def has_bucket(self, bucket_name):
|
||||
return self._client.bucket_exists(bucket_name)
|
||||
|
||||
def put_object(self, bucket_name, object_name, data):
|
||||
logger.debug(f"Uploading '{object_name}'...")
|
||||
data = io.BytesIO(data)
|
||||
self._client.put_object(bucket_name, object_name, data, length=data.getbuffer().nbytes)
|
||||
|
||||
def exists(self, bucket_name, object_name):
|
||||
try:
|
||||
self._client.stat_object(bucket_name, object_name)
|
||||
return True
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
@retry(tries=3, delay=5, jitter=(1, 3))
|
||||
def get_object(self, bucket_name, object_name):
|
||||
logger.debug(f"Downloading '{object_name}'...")
|
||||
response = None
|
||||
|
||||
try:
|
||||
response = self._client.get_object(bucket_name, object_name)
|
||||
return response.data
|
||||
except Exception as err:
|
||||
raise Exception("Failed getting object from s3 client") from err
|
||||
finally:
|
||||
if response:
|
||||
response.close()
|
||||
response.release_conn()
|
||||
|
||||
def get_all_objects(self, bucket_name):
|
||||
for obj in self._client.list_objects(bucket_name, recursive=True):
|
||||
logger.debug(f"Downloading '{obj.object_name}'...")
|
||||
yield self.get_object(bucket_name, obj.object_name)
|
||||
|
||||
def clear_bucket(self, bucket_name):
|
||||
logger.debug(f"Clearing S3 bucket '{bucket_name}'...")
|
||||
objects = self._client.list_objects(bucket_name, recursive=True)
|
||||
for obj in objects:
|
||||
self._client.remove_object(bucket_name, obj.object_name)
|
||||
|
||||
def get_all_object_names(self, bucket_name):
|
||||
objs = self._client.list_objects(bucket_name, recursive=True)
|
||||
return zip(repeat(bucket_name), map(attrgetter("object_name"), objs))
|
||||
|
||||
|
||||
def _parse_endpoint(endpoint):
|
||||
parsed_url = urlparse(endpoint)
|
||||
if URL_VALIDATOR.match(endpoint) and parsed_url.netloc and parsed_url.scheme in ALLOWED_CONNECTION_SCHEMES:
|
||||
return {"secure": parsed_url.scheme == "https", "endpoint": parsed_url.netloc}
|
||||
else:
|
||||
raise Exception(f"The configured storage endpoint is not a valid url: {endpoint}")
|
||||
|
||||
|
||||
def get_s3_storage(config: Config):
|
||||
return S3StorageAdapter(
|
||||
Minio(
|
||||
**_parse_endpoint(config.storage_endpoint),
|
||||
access_key=config.storage_key,
|
||||
secret_key=config.storage_secret,
|
||||
# This is relevant for running on s3
|
||||
region=config.storage_region,
|
||||
)
|
||||
)
|
||||
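For clarity, _parse_endpoint accepts http/https URLs and returns the keyword arguments handed to Minio; an illustrative check evaluated against the helper above, with made-up hostnames:

# Illustrative only.
print(_parse_endpoint("https://minio.example.com:9000"))
# {'secure': True, 'endpoint': 'minio.example.com:9000'}
print(_parse_endpoint("http://localhost:9000"))
# {'secure': False, 'endpoint': 'localhost:9000'}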
@ -1,89 +0,0 @@
|
||||
from functools import lru_cache
|
||||
|
||||
import requests
|
||||
from dynaconf import Dynaconf
|
||||
from kn_utils.logging import logger
|
||||
|
||||
from pyinfra.config.loader import validate_settings
|
||||
from pyinfra.config.validators import (
|
||||
multi_tenant_storage_validators,
|
||||
storage_validators,
|
||||
)
|
||||
from pyinfra.storage.storages.azure import get_azure_storage_from_settings
|
||||
from pyinfra.storage.storages.s3 import get_s3_storage_from_settings
|
||||
from pyinfra.storage.storages.storage import Storage
|
||||
from pyinfra.utils.cipher import decrypt
|
||||
|
||||
|
||||
def get_storage(settings: Dynaconf, tenant_id: str = None) -> Storage:
|
||||
"""Establishes a storage connection.
|
||||
If tenant_id is provided, gets storage connection information from tenant server. These connections are cached.
|
||||
Otherwise, gets storage connection information from settings.
|
||||
"""
|
||||
logger.info("Establishing storage connection...")
|
||||
|
||||
if tenant_id:
|
||||
logger.info(f"Using tenant storage for {tenant_id}.")
|
||||
validate_settings(settings, multi_tenant_storage_validators)
|
||||
|
||||
return get_storage_for_tenant(
|
||||
tenant_id,
|
||||
settings.storage.tenant_server.endpoint,
|
||||
settings.storage.tenant_server.public_key,
|
||||
)
|
||||
|
||||
logger.info("Using default storage.")
|
||||
validate_settings(settings, storage_validators)
|
||||
|
||||
return storage_dispatcher[settings.storage.backend](settings)
|
||||
|
||||
|
||||
storage_dispatcher = {
|
||||
"azure": get_azure_storage_from_settings,
|
||||
"s3": get_s3_storage_from_settings,
|
||||
}
|
||||
|
||||
|
||||
@lru_cache(maxsize=10)
|
||||
def get_storage_for_tenant(tenant: str, endpoint: str, public_key: str) -> Storage:
|
||||
response = requests.get(f"{endpoint}/{tenant}").json()
|
||||
|
||||
maybe_azure = response.get("azureStorageConnection")
|
||||
maybe_s3 = response.get("s3StorageConnection")
|
||||
|
||||
assert (maybe_azure or maybe_s3) and not (maybe_azure and maybe_s3), "Only one storage backend can be used."
|
||||
|
||||
if maybe_azure:
|
||||
connection_string = decrypt(public_key, maybe_azure["connectionString"])
|
||||
backend = "azure"
|
||||
storage_info = {
|
||||
"storage": {
|
||||
"azure": {
|
||||
"connection_string": connection_string,
|
||||
"container": maybe_azure["containerName"],
|
||||
},
|
||||
}
|
||||
}
|
||||
elif maybe_s3:
|
||||
secret = decrypt(public_key, maybe_s3["secret"])
|
||||
backend = "s3"
|
||||
storage_info = {
|
||||
"storage": {
|
||||
"s3": {
|
||||
"endpoint": maybe_s3["endpoint"],
|
||||
"key": maybe_s3["key"],
|
||||
"secret": secret,
|
||||
"region": maybe_s3["region"],
|
||||
"bucket": maybe_s3["bucketName"],
|
||||
},
|
||||
}
|
||||
}
|
||||
else:
|
||||
raise Exception(f"Unknown storage backend in {response}.")
|
||||
|
||||
storage_settings = Dynaconf()
|
||||
storage_settings.update(storage_info)
|
||||
|
||||
storage = storage_dispatcher[backend](storage_settings)
|
||||
|
||||
return storage
|
||||
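The tenant-server responses handled by get_storage_for_tenant are assumed to look roughly like the following; field names are taken from the code above, the values are invented, and exactly one of the two connection blocks is expected per tenant:

# Azure-backed tenant (connectionString is encrypted and decoded via decrypt()):
azure_response = {
    "azureStorageConnection": {
        "connectionString": "<encrypted>",
        "containerName": "tenant-a-container",
    }
}

# S3-backed tenant (secret is encrypted):
s3_response = {
    "s3StorageConnection": {
        "endpoint": "https://minio.example.com:9000",
        "key": "<access-key>",
        "secret": "<encrypted>",
        "region": "eu-central-1",
        "bucketName": "tenant-a-bucket",
    }
}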
17
pyinfra/storage/storage.py
Normal file
@ -0,0 +1,17 @@
|
||||
import logging
|
||||
|
||||
from pyinfra.config import get_config, Config
|
||||
from pyinfra.storage.adapters.azure import get_azure_storage
|
||||
from pyinfra.storage.adapters.s3 import get_s3_storage
|
||||
|
||||
|
||||
def get_storage(config: Config):
|
||||
|
||||
if config.storage_backend == "s3":
|
||||
storage = get_s3_storage(config)
|
||||
elif config.storage_backend == "azure":
|
||||
storage = get_azure_storage(config)
|
||||
else:
|
||||
raise Exception(f"Unknown storage backend '{config.storage_backend}'.")
|
||||
|
||||
return storage
|
||||
@ -1,91 +0,0 @@
|
||||
import logging
|
||||
from itertools import repeat
|
||||
from operator import attrgetter
|
||||
|
||||
from azure.storage.blob import BlobServiceClient, ContainerClient
|
||||
from dynaconf import Dynaconf
|
||||
from kn_utils.logging import logger
|
||||
from retry import retry
|
||||
|
||||
from pyinfra.config.loader import validate_settings
|
||||
from pyinfra.config.validators import azure_storage_validators
|
||||
from pyinfra.storage.storages.storage import Storage
|
||||
|
||||
logging.getLogger("azure").setLevel(logging.WARNING)
|
||||
logging.getLogger("urllib3").setLevel(logging.WARNING)
|
||||
|
||||
|
||||
class AzureStorage(Storage):
|
||||
def __init__(self, client: BlobServiceClient, bucket: str):
|
||||
self._client: BlobServiceClient = client
|
||||
self._bucket = bucket
|
||||
|
||||
@property
|
||||
def bucket(self):
|
||||
return self._bucket
|
||||
|
||||
def has_bucket(self):
|
||||
container_client = self._client.get_container_client(self.bucket)
|
||||
return container_client.exists()
|
||||
|
||||
def make_bucket(self):
|
||||
container_client = self._client.get_container_client(self.bucket)
|
||||
if not container_client.exists():
    self._client.create_container(self.bucket)
|
||||
|
||||
def __provide_container_client(self) -> ContainerClient:
|
||||
self.make_bucket()
|
||||
container_client = self._client.get_container_client(self.bucket)
|
||||
return container_client
|
||||
|
||||
def put_object(self, object_name, data):
|
||||
logger.debug(f"Uploading '{object_name}'...")
|
||||
container_client = self.__provide_container_client()
|
||||
blob_client = container_client.get_blob_client(object_name)
|
||||
blob_client.upload_blob(data, overwrite=True)
|
||||
|
||||
def exists(self, object_name):
|
||||
container_client = self.__provide_container_client()
|
||||
blob_client = container_client.get_blob_client(object_name)
|
||||
return blob_client.exists()
|
||||
|
||||
@retry(tries=3, delay=5, jitter=(1, 3))
|
||||
def get_object(self, object_name):
|
||||
logger.debug(f"Downloading '{object_name}'...")
|
||||
|
||||
try:
|
||||
container_client = self.__provide_container_client()
|
||||
blob_client = container_client.get_blob_client(object_name)
|
||||
blob_data = blob_client.download_blob()
|
||||
return blob_data.readall()
|
||||
except Exception as err:
|
||||
raise Exception("Failed getting object from azure client") from err
|
||||
|
||||
def get_all_objects(self):
|
||||
container_client = self.__provide_container_client()
|
||||
blobs = container_client.list_blobs()
|
||||
for blob in blobs:
|
||||
logger.debug(f"Downloading '{blob.name}'...")
|
||||
blob_client = container_client.get_blob_client(blob)
|
||||
blob_data = blob_client.download_blob()
|
||||
data = blob_data.readall()
|
||||
yield data
|
||||
|
||||
def clear_bucket(self):
|
||||
logger.debug(f"Clearing Azure container '{self.bucket}'...")
|
||||
container_client = self._client.get_container_client(self.bucket)
|
||||
blobs = container_client.list_blobs()
|
||||
container_client.delete_blobs(*blobs)
|
||||
|
||||
def get_all_object_names(self):
|
||||
container_client = self.__provide_container_client()
|
||||
blobs = container_client.list_blobs()
|
||||
return zip(repeat(self.bucket), map(attrgetter("name"), blobs))
|
||||
|
||||
|
||||
def get_azure_storage_from_settings(settings: Dynaconf):
|
||||
validate_settings(settings, azure_storage_validators)
|
||||
|
||||
return AzureStorage(
|
||||
client=BlobServiceClient.from_connection_string(conn_str=settings.storage.azure.connection_string),
|
||||
bucket=settings.storage.azure.container,
|
||||
)
|
||||
@ -1,89 +0,0 @@
|
||||
import io
|
||||
from itertools import repeat
|
||||
from operator import attrgetter
|
||||
|
||||
from dynaconf import Dynaconf
|
||||
from kn_utils.logging import logger
|
||||
from minio import Minio
|
||||
from retry import retry
|
||||
|
||||
from pyinfra.config.loader import validate_settings
|
||||
from pyinfra.config.validators import s3_storage_validators
|
||||
from pyinfra.storage.storages.storage import Storage
|
||||
from pyinfra.utils.url_parsing import validate_and_parse_s3_endpoint
|
||||
|
||||
|
||||
class S3Storage(Storage):
|
||||
def __init__(self, client: Minio, bucket: str):
|
||||
self._client = client
|
||||
self._bucket = bucket
|
||||
|
||||
@property
|
||||
def bucket(self):
|
||||
return self._bucket
|
||||
|
||||
def make_bucket(self):
|
||||
if not self.has_bucket():
|
||||
self._client.make_bucket(self.bucket)
|
||||
|
||||
def has_bucket(self):
|
||||
return self._client.bucket_exists(self.bucket)
|
||||
|
||||
def put_object(self, object_name, data):
|
||||
logger.debug(f"Uploading '{object_name}'...")
|
||||
data = io.BytesIO(data)
|
||||
self._client.put_object(self.bucket, object_name, data, length=data.getbuffer().nbytes)
|
||||
|
||||
def exists(self, object_name):
|
||||
try:
|
||||
self._client.stat_object(self.bucket, object_name)
|
||||
return True
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
@retry(tries=3, delay=5, jitter=(1, 3))
|
||||
def get_object(self, object_name):
|
||||
logger.debug(f"Downloading '{object_name}'...")
|
||||
response = None
|
||||
|
||||
try:
|
||||
response = self._client.get_object(self.bucket, object_name)
|
||||
return response.data
|
||||
except Exception as err:
|
||||
raise Exception("Failed getting object from s3 client") from err
|
||||
finally:
|
||||
if response:
|
||||
response.close()
|
||||
response.release_conn()
|
||||
|
||||
def get_all_objects(self):
|
||||
for obj in self._client.list_objects(self.bucket, recursive=True):
|
||||
logger.debug(f"Downloading '{obj.object_name}'...")
|
||||
yield self.get_object(obj.object_name)
|
||||
|
||||
def clear_bucket(self):
|
||||
logger.debug(f"Clearing S3 bucket '{self.bucket}'...")
|
||||
objects = self._client.list_objects(self.bucket, recursive=True)
|
||||
for obj in objects:
|
||||
self._client.remove_object(self.bucket, obj.object_name)
|
||||
|
||||
def get_all_object_names(self):
|
||||
objs = self._client.list_objects(self.bucket, recursive=True)
|
||||
return zip(repeat(self.bucket), map(attrgetter("object_name"), objs))
|
||||
|
||||
|
||||
def get_s3_storage_from_settings(settings: Dynaconf):
|
||||
validate_settings(settings, s3_storage_validators)
|
||||
|
||||
secure, endpoint = validate_and_parse_s3_endpoint(settings.storage.s3.endpoint)
|
||||
|
||||
return S3Storage(
|
||||
client=Minio(
|
||||
secure=secure,
|
||||
endpoint=endpoint,
|
||||
access_key=settings.storage.s3.key,
|
||||
secret_key=settings.storage.s3.secret,
|
||||
region=settings.storage.s3.region,
|
||||
),
|
||||
bucket=settings.storage.s3.bucket,
|
||||
)
|
||||
@ -1,40 +0,0 @@
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
|
||||
class Storage(ABC):
|
||||
@property
|
||||
@abstractmethod
|
||||
def bucket(self):
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def make_bucket(self):
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def has_bucket(self):
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def put_object(self, object_name, data):
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def exists(self, object_name):
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def get_object(self, object_name):
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def get_all_objects(self):
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def clear_bucket(self):
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def get_all_object_names(self):
|
||||
raise NotImplementedError
|
||||
@ -1,150 +0,0 @@
|
||||
import gzip
|
||||
import json
|
||||
from functools import singledispatch
|
||||
from typing import TypedDict
|
||||
|
||||
from kn_utils.logging import logger
|
||||
from pydantic import BaseModel, ValidationError
|
||||
|
||||
from pyinfra.storage.storages.storage import Storage
|
||||
|
||||
|
||||
class DossierIdFileIdDownloadPayload(BaseModel):
|
||||
dossierId: str
|
||||
fileId: str
|
||||
targetFileExtension: str
|
||||
|
||||
@property
|
||||
def targetFilePath(self):
|
||||
return f"{self.dossierId}/{self.fileId}.{self.targetFileExtension}"
|
||||
|
||||
|
||||
class TenantIdDossierIdFileIdDownloadPayload(BaseModel):
|
||||
tenantId: str
|
||||
dossierId: str
|
||||
fileId: str
|
||||
targetFileExtension: str
|
||||
|
||||
@property
|
||||
def targetFilePath(self):
|
||||
return f"{self.tenantId}/{self.dossierId}/{self.fileId}.{self.targetFileExtension}"
|
||||
|
||||
|
||||
class DossierIdFileIdUploadPayload(BaseModel):
|
||||
dossierId: str
|
||||
fileId: str
|
||||
responseFileExtension: str
|
||||
|
||||
@property
|
||||
def responseFilePath(self):
|
||||
return f"{self.dossierId}/{self.fileId}.{self.responseFileExtension}"
|
||||
|
||||
|
||||
class TenantIdDossierIdFileIdUploadPayload(BaseModel):
|
||||
tenantId: str
|
||||
dossierId: str
|
||||
fileId: str
|
||||
responseFileExtension: str
|
||||
|
||||
@property
|
||||
def responseFilePath(self):
|
||||
return f"{self.tenantId}/{self.dossierId}/{self.fileId}.{self.responseFileExtension}"
|
||||
|
||||
|
||||
class TargetResponseFilePathDownloadPayload(BaseModel):
|
||||
targetFilePath: str | dict[str, str]
|
||||
|
||||
|
||||
class TargetResponseFilePathUploadPayload(BaseModel):
|
||||
responseFilePath: str
|
||||
|
||||
|
||||
class DownloadedData(TypedDict):
|
||||
data: bytes
|
||||
file_path: str
|
||||
|
||||
|
||||
def download_data_bytes_as_specified_in_message(
|
||||
storage: Storage, raw_payload: dict
|
||||
) -> dict[str, DownloadedData] | DownloadedData:
|
||||
"""Convenience function to download a file specified in a message payload.
|
||||
Supports both legacy and new payload formats. Also supports downloading multiple files at once, which should
|
||||
be specified in a dictionary under the 'targetFilePath' key with the file path as value.
|
||||
The data is downloaded as bytes and returned as a dictionary with the file path as key and the data as value.
|
||||
In case of several download targets, a nested dictionary is returned with the same keys and dictionaries with
|
||||
the file path and data as values.
|
||||
"""
|
||||
|
||||
try:
|
||||
if "tenantId" in raw_payload and "dossierId" in raw_payload:
|
||||
payload = TenantIdDossierIdFileIdDownloadPayload(**raw_payload)
|
||||
elif "tenantId" not in raw_payload and "dossierId" in raw_payload:
|
||||
payload = DossierIdFileIdDownloadPayload(**raw_payload)
|
||||
else:
|
||||
payload = TargetResponseFilePathDownloadPayload(**raw_payload)
|
||||
except ValidationError:
|
||||
raise ValueError("No download file path found in payload, nothing to download.")
|
||||
|
||||
data = _download(payload.targetFilePath, storage)
|
||||
|
||||
return data
|
||||
|
||||
|
||||
@singledispatch
|
||||
def _download(
|
||||
file_path_or_file_path_dict: str | dict[str, str], storage: Storage
|
||||
) -> dict[str, DownloadedData] | DownloadedData:
|
||||
pass
|
||||
|
||||
|
||||
@_download.register(str)
|
||||
def _download_single_file(file_path: str, storage: Storage) -> DownloadedData:
|
||||
if not storage.exists(file_path):
|
||||
raise FileNotFoundError(f"File '{file_path}' does not exist in storage.")
|
||||
|
||||
data = storage.get_object(file_path)
|
||||
logger.info(f"Downloaded {file_path} from storage.")
|
||||
|
||||
return DownloadedData(data=data, file_path=file_path)
|
||||
|
||||
|
||||
@_download.register(dict)
|
||||
def _download_multiple_files(file_path_dict: dict, storage: Storage) -> dict[str, DownloadedData]:
|
||||
return {key: _download(value, storage) for key, value in file_path_dict.items()}
|
||||
|
||||
|
||||
def upload_data_as_specified_in_message(storage: Storage, raw_payload: dict, data):
|
||||
"""Convenience function to upload a file specified in a message payload. For now, only json serializable data is
|
||||
supported. The storage json consists of the raw_payload, which is extended with a 'data' key, containing the
|
||||
data to be uploaded.
|
||||
|
||||
If the content is not a json serializable object, an exception will be raised.
|
||||
If the result file identifier specifies compression with gzip (.gz), it will be compressed before upload.
|
||||
|
||||
This function can be extended in the future as needed (e.g. if we need to upload images), but since further
|
||||
requirements are not specified at this point in time, and it is unclear what these would entail, the code is kept
|
||||
simple for now to improve readability, maintainability and avoid refactoring efforts of generic solutions that
|
||||
weren't as generic as they seemed.
|
||||
"""
|
||||
|
||||
try:
|
||||
if "tenantId" in raw_payload and "dossierId" in raw_payload:
|
||||
payload = TenantIdDossierIdFileIdUploadPayload(**raw_payload)
|
||||
elif "tenantId" not in raw_payload and "dossierId" in raw_payload:
|
||||
payload = DossierIdFileIdUploadPayload(**raw_payload)
|
||||
else:
|
||||
payload = TargetResponseFilePathUploadPayload(**raw_payload)
|
||||
except ValidationError:
|
||||
raise ValueError("No upload file path found in payload, nothing to upload.")
|
||||
|
||||
if ".json" not in payload.responseFilePath:
|
||||
raise ValueError("Only json serializable data can be uploaded.")
|
||||
|
||||
data = {**raw_payload, "data": data}
|
||||
|
||||
data = json.dumps(data).encode("utf-8")
|
||||
data = gzip.compress(data) if ".gz" in payload.responseFilePath else data
|
||||
|
||||
storage.put_object(payload.responseFilePath, data)
|
||||
|
||||
logger.info(f"Uploaded {payload.responseFilePath} to storage.")
|
||||
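A minimal usage sketch for the two helpers above, assuming an S3 storage configured via the settings loader shown elsewhere in this changeset and a bucket that already contains the target object; the tenant, dossier and file identifiers are illustrative:

from pyinfra.config.loader import load_settings, local_pyinfra_root_path
from pyinfra.storage.storages.s3 import get_s3_storage_from_settings
from pyinfra.storage.utils import (
    download_data_bytes_as_specified_in_message,
    upload_data_as_specified_in_message,
)

settings = load_settings(local_pyinfra_root_path / "config/")
storage = get_s3_storage_from_settings(settings)

# New-style payload: the identifiers are joined into "<tenantId>/<dossierId>/<fileId>.<extension>".
payload = {
    "tenantId": "tenant1",
    "dossierId": "dossier1",
    "fileId": "file1",
    "targetFileExtension": "json.gz",
    "responseFileExtension": "result.json.gz",
}

# Assumes "tenant1/dossier1/file1.json.gz" already exists in the bucket.
downloaded = download_data_bytes_as_specified_in_message(storage, payload)
# downloaded == {"data": b"...", "file_path": "tenant1/dossier1/file1.json.gz"}

# The upload helper stores {**payload, "data": result} as JSON and gzips it
# because the response path ends in ".gz".
upload_data_as_specified_in_message(storage, payload, {"status": "done"})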
@ -1,49 +0,0 @@
|
||||
import base64
import os

from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives.ciphers.aead import AESGCM
from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC


def build_aes_gcm_cipher(public_key, iv=None):
    encoded_key = public_key.encode("utf-8")
    kdf = PBKDF2HMAC(
        algorithm=hashes.SHA1(),
        length=16,
        salt=iv,
        iterations=65536,
    )
    private_key = kdf.derive(encoded_key)
    return AESGCM(private_key)


def encrypt(public_key: str, plaintext: str, iv: bytes = None) -> str:
    """Encrypt a text with AES/GCM using a public key.

    The byte-converted ciphertext consists of an unsigned 32-bit integer big-endian byte-order header, i.e. the
    first 4 bytes, specifying the length of the following initialization vector (iv). The rest of the text
    contains the encrypted message.
    """
    iv = iv or os.urandom(12)
    plaintext_bytes = plaintext.encode("utf-8")
    cipher = build_aes_gcm_cipher(public_key, iv)
    header = len(iv).to_bytes(length=4, byteorder="big")
    encrypted = header + iv + cipher.encrypt(nonce=iv, data=plaintext_bytes, associated_data=None)
    return base64.b64encode(encrypted).decode("utf-8")


def decrypt(public_key: str, ciphertext: str) -> str:
    """Decrypt an AES/GCM encrypted text with a public key.

    The byte-converted ciphertext consists of an unsigned 32-bit integer big-endian byte-order header, i.e. the
    first 4 bytes, specifying the length of the following initialization vector (iv). The rest of the text
    contains the encrypted message.
    """
    ciphertext_bytes = base64.b64decode(ciphertext)
    header, rest = ciphertext_bytes[:4], ciphertext_bytes[4:]
    iv_length = int.from_bytes(header, "big")
    iv, ciphertext_bytes = rest[:iv_length], rest[iv_length:]
    cipher = build_aes_gcm_cipher(public_key, iv)
    decrypted_text = cipher.decrypt(nonce=iv, data=ciphertext_bytes, associated_data=None)
    return decrypted_text.decode("utf-8")
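A small round-trip sketch (the shared key and message are made up) that makes the wire format described in the docstrings explicit: base64 of a 4-byte big-endian IV-length header, followed by the IV and the AES-GCM ciphertext with its tag:

import base64

from pyinfra.utils.cipher import decrypt, encrypt

token = encrypt("shared-secret", "hello world")
raw = base64.b64decode(token)

iv_length = int.from_bytes(raw[:4], "big")             # header: length of the IV (12 for os.urandom(12))
iv, body = raw[4:4 + iv_length], raw[4 + iv_length:]   # random IV, then ciphertext plus 16-byte GCM tag
assert iv_length == 12
assert decrypt("shared-secret", token) == "hello world"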
@ -1,96 +0,0 @@
|
||||
import json
|
||||
|
||||
from azure.monitor.opentelemetry import configure_azure_monitor
|
||||
from dynaconf import Dynaconf
|
||||
from fastapi import FastAPI
|
||||
from kn_utils.logging import logger
|
||||
from opentelemetry import trace
|
||||
from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import OTLPSpanExporter
|
||||
from opentelemetry.instrumentation.aio_pika import AioPikaInstrumentor
|
||||
from opentelemetry.instrumentation.fastapi import FastAPIInstrumentor
|
||||
from opentelemetry.instrumentation.pika import PikaInstrumentor
|
||||
from opentelemetry.sdk.resources import Resource
|
||||
from opentelemetry.sdk.trace import TracerProvider
|
||||
from opentelemetry.sdk.trace.export import (
|
||||
BatchSpanProcessor,
|
||||
ConsoleSpanExporter,
|
||||
SpanExporter,
|
||||
SpanExportResult,
|
||||
)
|
||||
|
||||
from pyinfra.config.loader import validate_settings
|
||||
from pyinfra.config.validators import opentelemetry_validators
|
||||
|
||||
|
||||
class JsonSpanExporter(SpanExporter):
|
||||
def __init__(self):
|
||||
self.traces = []
|
||||
|
||||
def export(self, spans):
|
||||
for span in spans:
|
||||
self.traces.append(json.loads(span.to_json()))
|
||||
return SpanExportResult.SUCCESS
|
||||
|
||||
def shutdown(self):
|
||||
pass
|
||||
|
||||
|
||||
def setup_trace(settings: Dynaconf, service_name: str = None, exporter: SpanExporter = None):
|
||||
tracing_type = settings.tracing.type
|
||||
if tracing_type == "azure_monitor":
|
||||
# Configure OpenTelemetry to use Azure Monitor with the
|
||||
# APPLICATIONINSIGHTS_CONNECTION_STRING environment variable.
|
||||
try:
|
||||
configure_azure_monitor()
|
||||
logger.info("Azure Monitor tracing enabled.")
|
||||
except Exception as exception:
|
||||
logger.warning(f"Azure Monitor tracing could not be enabled: {exception}")
|
||||
elif tracing_type == "opentelemetry":
|
||||
configure_opentelemetry_tracing(settings, service_name, exporter)
|
||||
logger.info("OpenTelemetry tracing enabled.")
|
||||
else:
|
||||
logger.warning(f"Unknown tracing type: {tracing_type}. Tracing could not be enabled.")
|
||||
|
||||
|
||||
def configure_opentelemetry_tracing(settings: Dynaconf, service_name: str = None, exporter: SpanExporter = None):
|
||||
service_name = service_name or settings.tracing.opentelemetry.service_name
|
||||
exporter = exporter or get_exporter(settings)
|
||||
|
||||
resource = Resource(attributes={"service.name": service_name})
|
||||
provider = TracerProvider(resource=resource, shutdown_on_exit=True)
|
||||
|
||||
processor = BatchSpanProcessor(exporter)
|
||||
provider.add_span_processor(processor)
|
||||
|
||||
# TODO: trace.set_tracer_provider produces a warning if trying to set the provider twice.
|
||||
# "WARNING opentelemetry.trace:__init__.py:521 Overriding of current TracerProvider is not allowed"
|
||||
# This doesn't seem to affect the functionality since we only want to use the tracer provider set in the beginning.
|
||||
# We work around the log message by using the protected method with log=False.
|
||||
trace._set_tracer_provider(provider, log=False)
|
||||
|
||||
|
||||
def get_exporter(settings: Dynaconf):
|
||||
validate_settings(settings, validators=opentelemetry_validators)
|
||||
|
||||
if settings.tracing.opentelemetry.exporter == "json":
|
||||
return JsonSpanExporter()
|
||||
elif settings.tracing.opentelemetry.exporter == "otlp":
|
||||
return OTLPSpanExporter(endpoint=settings.tracing.opentelemetry.endpoint)
|
||||
elif settings.tracing.opentelemetry.exporter == "console":
|
||||
return ConsoleSpanExporter()
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Invalid OpenTelemetry exporter {settings.tracing.opentelemetry.exporter}. "
|
||||
f"Valid values are 'json', 'otlp' and 'console'."
|
||||
)
|
||||
|
||||
|
||||
def instrument_pika(dynamic_queues: bool):
|
||||
if dynamic_queues:
|
||||
AioPikaInstrumentor().instrument()
|
||||
else:
|
||||
PikaInstrumentor().instrument()
|
||||
|
||||
|
||||
def instrument_app(app: FastAPI, excluded_urls: str = "/health,/ready,/prometheus"):
|
||||
FastAPIInstrumentor().instrument_app(app, excluded_urls=excluded_urls)
|
||||
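A hedged sketch of how a service might combine these helpers; the service name is an example value and the settings object is assumed to be loaded via pyinfra.config.loader and to contain the tracing keys validated above:

from fastapi import FastAPI

from pyinfra.config.loader import load_settings, local_pyinfra_root_path
from pyinfra.utils.opentelemetry import instrument_app, instrument_pika, setup_trace

settings = load_settings(local_pyinfra_root_path / "config/")
app = FastAPI()

setup_trace(settings, service_name="example-service")  # picks azure_monitor or opentelemetry from settings
instrument_app(app)                                     # skips /health, /ready and /prometheus by default
instrument_pika(dynamic_queues=False)                   # pika for static queues, aio-pika for dynamic ones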
@ -1,40 +0,0 @@
|
||||
import re
|
||||
from operator import truth
|
||||
from typing import Tuple
|
||||
from urllib.parse import urlparse
|
||||
|
||||
|
||||
def make_url_validator(allowed_connection_schemes: tuple = ("http", "https")):
|
||||
pattern = re.compile(
|
||||
r"^(("
|
||||
+ r"([A-Za-z]{3,9}:(?:\/\/)?)"
|
||||
+ r"(?:[\-;:&=\+\$,\w]+@)?"
|
||||
+ r"[A-Za-z0-9\.\-]+|(?:www\.|[\-;:&=\+\$,\w]+@)"
|
||||
+ r"[A-Za-z0-9\.\-]+)"
|
||||
+ r"((?:\/[\+~%\/\.\w\-_]*)?"
|
||||
+ r"\??(?:[\-\+=&;%@\.\w_]*)#?(?:[\.\!\/\\\w]*))?)"
|
||||
)
|
||||
|
||||
def inner(url: str):
|
||||
url_is_valid = pattern.match(url)
|
||||
|
||||
parsed_url = urlparse(url)
|
||||
endpoint_is_valid = truth(parsed_url.netloc)
|
||||
protocol_is_valid = parsed_url.scheme in allowed_connection_schemes
|
||||
|
||||
return url_is_valid and endpoint_is_valid and protocol_is_valid
|
||||
|
||||
return inner
|
||||
|
||||
|
||||
def validate_and_parse_s3_endpoint(endpoint: str) -> Tuple[bool, str]:
|
||||
validate_url = make_url_validator()
|
||||
|
||||
if not validate_url(endpoint):
|
||||
raise Exception(f"The s3 storage endpoint is not a valid url: {endpoint}")
|
||||
|
||||
parsed_url = urlparse(endpoint)
|
||||
connection_is_secure = parsed_url.scheme == "https"
|
||||
storage_endpoint = parsed_url.netloc
|
||||
|
||||
return connection_is_secure, storage_endpoint
|
||||
@ -1,64 +0,0 @@
|
||||
from time import time
|
||||
from typing import Callable, TypeVar
|
||||
|
||||
from dynaconf import Dynaconf
|
||||
from fastapi import FastAPI
|
||||
from funcy import identity
|
||||
from prometheus_client import REGISTRY, CollectorRegistry, Summary, generate_latest
|
||||
from starlette.responses import Response
|
||||
|
||||
from pyinfra.config.loader import validate_settings
|
||||
from pyinfra.config.validators import prometheus_validators
|
||||
|
||||
|
||||
def add_prometheus_endpoint(app: FastAPI, registry: CollectorRegistry = REGISTRY) -> FastAPI:
|
||||
"""Add a prometheus endpoint to the app. It is recommended to use the default global registry.
|
||||
You can register your own metrics with it anywhere, and they will be scraped with this endpoint.
|
||||
See https://prometheus.io/docs/concepts/metric_types/ for the different metric types.
|
||||
The implementation for monitoring the processing time of a function is in the decorator below (decorate the
|
||||
processing function of a service to assess the processing time of each call).
|
||||
|
||||
The convention for the metric name is {product_name}_{service_name}_{parameter_to_monitor}.
|
||||
"""
|
||||
|
||||
@app.get("/prometheus")
|
||||
def prometheus_metrics():
|
||||
return Response(generate_latest(registry), media_type="text/plain")
|
||||
|
||||
return app
|
||||
|
||||
|
||||
Decorator = TypeVar("Decorator", bound=Callable[[Callable], Callable])
|
||||
|
||||
|
||||
def make_prometheus_processing_time_decorator_from_settings(
|
||||
settings: Dynaconf,
|
||||
postfix: str = "processing_time",
|
||||
registry: CollectorRegistry = REGISTRY,
|
||||
) -> Decorator:
|
||||
"""Make a decorator for monitoring the processing time of a function. This, and other metrics should follow the
|
||||
convention {product name}_{service name}_{processing step / parameter to monitor}.
|
||||
"""
|
||||
validate_settings(settings, validators=prometheus_validators)
|
||||
|
||||
processing_time_sum = Summary(
|
||||
f"{settings.metrics.prometheus.prefix}_{postfix}",
|
||||
"Summed up processing time per call.",
|
||||
registry=registry,
|
||||
)
|
||||
|
||||
def decorator(process_fn: Callable) -> Callable:
|
||||
def inner(*args, **kwargs):
|
||||
start = time()
|
||||
|
||||
result = process_fn(*args, **kwargs)
|
||||
|
||||
runtime = time() - start
|
||||
|
||||
processing_time_sum.observe(runtime)
|
||||
|
||||
return result
|
||||
|
||||
return inner
|
||||
|
||||
return decorator
|
||||
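A hedged sketch of how a service could combine the two helpers above; the metric prefix and the process() function are illustrative, the import paths follow the tests in this changeset:

from fastapi import FastAPI

from pyinfra.config.loader import load_settings, local_pyinfra_root_path
from pyinfra.webserver.prometheus import (
    add_prometheus_endpoint,
    make_prometheus_processing_time_decorator_from_settings,
)

settings = load_settings(local_pyinfra_root_path / "config/")

# Expose GET /prometheus on the service's FastAPI app, backed by the default global registry.
app = add_prometheus_endpoint(FastAPI())

# The settings are expected to define metrics.prometheus.prefix, e.g. "redaction_ner";
# the resulting metric is then "redaction_ner_processing_time".
monitor_processing_time = make_prometheus_processing_time_decorator_from_settings(settings)


@monitor_processing_time
def process(message: dict) -> dict:  # hypothetical processing function of a service
    return {"status": "success"}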
@ -1,103 +0,0 @@
|
||||
import asyncio
|
||||
import inspect
|
||||
import logging
|
||||
import signal
|
||||
import threading
|
||||
import time
|
||||
from typing import Callable
|
||||
|
||||
import uvicorn
|
||||
from dynaconf import Dynaconf
|
||||
from fastapi import FastAPI
|
||||
from kn_utils.logging import logger
|
||||
from kn_utils.retry import retry
|
||||
|
||||
from pyinfra.config.loader import validate_settings
|
||||
from pyinfra.config.validators import webserver_validators
|
||||
|
||||
|
||||
class PyInfraUvicornServer(uvicorn.Server):
|
||||
# this is a workaround to enable custom signal handlers
|
||||
# https://github.com/encode/uvicorn/issues/1579
|
||||
def install_signal_handlers(self):
|
||||
pass
|
||||
|
||||
|
||||
@retry(
|
||||
tries=5,
|
||||
exceptions=Exception,
|
||||
reraise=True,
|
||||
)
|
||||
def create_webserver_thread_from_settings(app: FastAPI, settings: Dynaconf) -> threading.Thread:
|
||||
validate_settings(settings, validators=webserver_validators)
|
||||
return create_webserver_thread(app=app, port=settings.webserver.port, host=settings.webserver.host)
|
||||
|
||||
|
||||
def create_webserver_thread(app: FastAPI, port: int, host: str) -> threading.Thread:
|
||||
"""Creates a thread that runs a FastAPI webserver. Start with thread.start(), and join with thread.join().
|
||||
Note that the thread is a daemon thread, so it will be terminated when the main thread is terminated.
|
||||
"""
|
||||
|
||||
def run_server():
|
||||
retries = 5
|
||||
for attempt in range(retries):
|
||||
try:
|
||||
uvicorn.run(app, port=port, host=host, log_level=logging.WARNING)
|
||||
break
|
||||
except Exception as e:
|
||||
if attempt < retries - 1: # if it's not the last attempt
|
||||
logger.warning(f"Attempt {attempt + 1} failed to start the server: {e}. Retrying...")
|
||||
time.sleep(2**attempt) # exponential backoff
|
||||
else:
|
||||
logger.error(f"Failed to start the server after {retries} attempts: {e}")
|
||||
raise
|
||||
|
||||
thread = threading.Thread(target=run_server)
|
||||
thread.daemon = True
|
||||
return thread
|
||||
|
||||
|
||||
async def run_async_webserver(app: FastAPI, port: int, host: str):
|
||||
"""Run the FastAPI web server async."""
|
||||
config = uvicorn.Config(app, host=host, port=port, log_level=logging.WARNING)
|
||||
server = PyInfraUvicornServer(config)
|
||||
|
||||
try:
|
||||
await server.serve()
|
||||
except asyncio.CancelledError:
|
||||
logger.debug("Webserver was cancelled.")
|
||||
server.should_exit = True
|
||||
await server.shutdown()
|
||||
except Exception as e:
|
||||
logger.error(f"Error while running the webserver: {e}", exc_info=True)
|
||||
finally:
|
||||
logger.info("Webserver has been shut down.")
|
||||
|
||||
|
||||
HealthFunction = Callable[[], bool]
|
||||
|
||||
|
||||
def add_health_check_endpoint(app: FastAPI, health_function: HealthFunction) -> FastAPI:
|
||||
"""Add a health check endpoint to the app. The health function should return True if the service is healthy,
|
||||
and False otherwise. The health function is called when the endpoint is hit.
|
||||
"""
|
||||
if inspect.iscoroutinefunction(health_function):
|
||||
|
||||
@app.get("/health")
|
||||
@app.get("/ready")
|
||||
async def async_check_health():
|
||||
alive = await health_function()
|
||||
if alive:
|
||||
return {"status": "OK"}, 200
|
||||
return {"status": "Service Unavailable"}, 503
|
||||
|
||||
else:
|
||||
|
||||
@app.get("/health")
|
||||
@app.get("/ready")
|
||||
def check_health():
|
||||
if health_function():
|
||||
return {"status": "OK"}, 200
|
||||
return {"status": "Service Unavailable"}, 503
|
||||
|
||||
return app
|
||||
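A brief sketch of the intended usage described in the docstrings above; the host, port and health function are made-up example values:

import time

from fastapi import FastAPI

from pyinfra.webserver.utils import add_health_check_endpoint, create_webserver_thread

app = add_health_check_endpoint(FastAPI(), health_function=lambda: True)  # GET /health and /ready

# Daemon thread: it dies with the main thread, so keep the main thread busy (e.g. with the queue consumer).
thread = create_webserver_thread(app, port=8080, host="0.0.0.0")
thread.start()

time.sleep(60)  # stand-in for the service's real main loop
thread.join(timeout=1)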
108
pyproject.toml
108
pyproject.toml
@ -1,103 +1,41 @@
|
||||
[tool.poetry]
|
||||
name = "pyinfra"
|
||||
version = "4.1.0"
|
||||
version = "1.3.2"
|
||||
description = ""
|
||||
authors = ["Team Research <research@knecon.com>"]
|
||||
authors = ["Francisco Schulz <francisco.schulz@iqser.com>"]
|
||||
license = "All rights reseverd"
|
||||
readme = "README.md"
|
||||
|
||||
[tool.poetry.dependencies]
|
||||
python = ">=3.10,<3.11"
|
||||
# infra, deployment
|
||||
pika = "^1.3"
|
||||
retry = "^0.9"
|
||||
minio = "^7.1"
|
||||
prometheus-client = "^0.18"
|
||||
# azure
|
||||
azure-core = "^1.29"
|
||||
azure-storage-blob = "^12.13"
|
||||
# misc utils
|
||||
funcy = "^2"
|
||||
pycryptodome = "^3.19"
|
||||
fastapi = "^0.109.0"
|
||||
uvicorn = "^0.26.0"
|
||||
python = "~3.8"
|
||||
pika = "^1.2.0"
|
||||
retry = "^0.9.2"
|
||||
minio = "^7.1.3"
|
||||
azure-core = "^1.22.1"
|
||||
azure-storage-blob = "^12.9.0"
|
||||
testcontainers = "^3.4.2"
|
||||
docker-compose = "^1.29.2"
|
||||
funcy = "^1.17"
|
||||
prometheus-client = "^0.16.0"
|
||||
|
||||
# DONT USE GROUPS BECAUSE THEY ARE NOT INSTALLED FOR PACKAGES
|
||||
# [tool.poetry.group.internal.dependencies] <<< THIS IS NOT WORKING
|
||||
kn-utils = { version = ">=0.4.0", source = "nexus" }
|
||||
# We set all opentelemetry dependencies to lower bound because the image classification service depends on a protobuf version <4, but does not use proto files.
|
||||
# Therefore, we allow latest possible protobuf version in the services which use proto files. As soon as the dependency issue is fixed set this to the latest possible opentelemetry version
|
||||
opentelemetry-instrumentation-pika = ">=0.46b0,<0.50"
|
||||
opentelemetry-exporter-otlp = ">=1.25.0,<1.29"
|
||||
opentelemetry-instrumentation = ">=0.46b0,<0.50"
|
||||
opentelemetry-api = ">=1.25.0,<1.29"
|
||||
opentelemetry-sdk = ">=1.25.0,<1.29"
|
||||
opentelemetry-exporter-otlp-proto-http = ">=1.25.0,<1.29"
|
||||
opentelemetry-instrumentation-flask = ">=0.46b0,<0.50"
|
||||
opentelemetry-instrumentation-requests = ">=0.46b0,<0.50"
|
||||
opentelemetry-instrumentation-fastapi = ">=0.46b0,<0.50"
|
||||
opentelemetry-instrumentation-aio-pika = ">=0.46b0,<0.50"
|
||||
wcwidth = "<=0.2.12"
|
||||
azure-monitor-opentelemetry = "^1.6.0"
|
||||
aio-pika = "^9.4.2"
|
||||
aiohttp = "^3.9.5"
|
||||
|
||||
# THIS IS NOT AVAILABLE FOR SERVICES THAT IMPLEMENT PYINFRA
|
||||
[tool.poetry.group.dev.dependencies]
|
||||
pytest = "^7"
|
||||
ipykernel = "^6.26.0"
|
||||
black = "^24.10"
|
||||
pylint = "^3"
|
||||
coverage = "^7.3"
|
||||
requests = "^2.31"
|
||||
pre-commit = "^3.6.0"
|
||||
cyclonedx-bom = "^4.1.1"
|
||||
dvc = "^3.51.2"
|
||||
dvc-azure = "^3.1.0"
|
||||
deepdiff = "^7.0.1"
|
||||
pytest-cov = "^5.0.0"
|
||||
pytest = "^7.1.3"
|
||||
ipykernel = "^6.16.0"
|
||||
black = {version = "^23.1a1", allow-prereleases = true}
|
||||
pylint = "^2.15.10"
|
||||
coverage = "^7.2.0"
|
||||
requests = "^2.28.2"
|
||||
|
||||
[tool.pytest.ini_options]
|
||||
minversion = "6.0"
|
||||
addopts = "-ra -q"
|
||||
testpaths = ["tests", "integration"]
|
||||
testpaths = [
|
||||
"tests",
|
||||
"integration",
|
||||
]
|
||||
log_cli = 1
|
||||
log_cli_level = "DEBUG"
|
||||
|
||||
[tool.mypy]
|
||||
exclude = ['.venv']
|
||||
|
||||
[tool.black]
|
||||
line-length = 120
|
||||
target-version = ["py310"]
|
||||
|
||||
[tool.isort]
|
||||
profile = "black"
|
||||
|
||||
[tool.pylint.format]
|
||||
max-line-length = 120
|
||||
disable = [
|
||||
"C0114",
|
||||
"C0325",
|
||||
"R0801",
|
||||
"R0902",
|
||||
"R0903",
|
||||
"R0904",
|
||||
"R0913",
|
||||
"R0914",
|
||||
"W0511",
|
||||
]
|
||||
docstring-min-length = 3
|
||||
|
||||
[[tool.poetry.source]]
|
||||
name = "pypi-proxy"
|
||||
url = "https://nexus.knecon.com/repository/pypi-proxy/simple"
|
||||
priority = "primary"
|
||||
|
||||
[[tool.poetry.source]]
|
||||
name = "nexus"
|
||||
url = "https://nexus.knecon.com/repository/python/simple"
|
||||
priority = "explicit"
|
||||
|
||||
[build-system]
|
||||
requires = ["poetry-core"]
|
||||
build-backend = "poetry.core.masonry.api"
|
||||
|
||||
4
pytest.ini
Normal file
4
pytest.ini
Normal file
@ -0,0 +1,4 @@
|
||||
[pytest]
log_cli = 1
log_cli_level = DEBUG
9
requirements.txt
Executable file
9
requirements.txt
Executable file
@ -0,0 +1,9 @@
|
||||
pika==1.2.0
|
||||
retry==0.9.2
|
||||
minio==7.1.3
|
||||
azure-core==1.22.1
|
||||
azure-storage-blob==12.9.0
|
||||
testcontainers==3.4.2
|
||||
docker-compose==1.29.2
|
||||
pytest~=7.0.1
|
||||
funcy==1.17
|
||||
9
run_tests.sh
Normal file
9
run_tests.sh
Normal file
@ -0,0 +1,9 @@
|
||||
echo "${bamboo_nexus_password}" | docker login --username "${bamboo_nexus_user}" --password-stdin nexus.iqser.com:5001
|
||||
docker build -f Dockerfile_tests -t pyinfra-tests .
|
||||
|
||||
rnd=$(date +"%s")
|
||||
name=pyinfra-tests-${rnd}
|
||||
|
||||
echo "running tests container"
|
||||
|
||||
docker run --rm --net=host --name $name -v $PWD:$PWD -w $PWD -v /var/run/docker.sock:/var/run/docker.sock pyinfra-tests
|
||||
@ -1,150 +0,0 @@
|
||||
import asyncio
|
||||
import gzip
|
||||
import json
|
||||
from operator import itemgetter
|
||||
from typing import Any, Dict
|
||||
|
||||
from aio_pika import Message
|
||||
from aio_pika.abc import AbstractIncomingMessage
|
||||
from kn_utils.logging import logger
|
||||
|
||||
from pyinfra.config.loader import load_settings, local_pyinfra_root_path
|
||||
from pyinfra.queue.async_manager import AsyncQueueManager, RabbitMQConfig
|
||||
from pyinfra.storage.storages.s3 import S3Storage, get_s3_storage_from_settings
|
||||
|
||||
settings = load_settings(local_pyinfra_root_path / "config/")
|
||||
|
||||
|
||||
async def dummy_message_processor(message: Dict[str, Any]) -> Dict[str, Any]:
|
||||
logger.info(f"Processing message: {message}")
|
||||
# await asyncio.sleep(1) # Simulate processing time
|
||||
|
||||
storage = get_s3_storage_from_settings(settings)
|
||||
tenant_id, dossier_id, file_id = itemgetter("tenantId", "dossierId", "fileId")(message)
|
||||
suffix = message["responseFileExtension"]
|
||||
|
||||
object_name = f"{tenant_id}/{dossier_id}/{file_id}.{message['targetFileExtension']}"
|
||||
original_content = json.loads(gzip.decompress(storage.get_object(object_name)))
|
||||
processed_content = {
|
||||
"processedPages": original_content["numberOfPages"],
|
||||
"processedSectionTexts": f"Processed: {original_content['sectionTexts']}",
|
||||
}
|
||||
|
||||
processed_object_name = f"{tenant_id}/{dossier_id}/{file_id}.{suffix}"
|
||||
processed_data = gzip.compress(json.dumps(processed_content).encode("utf-8"))
|
||||
storage.put_object(processed_object_name, processed_data)
|
||||
|
||||
processed_message = message.copy()
|
||||
processed_message["processed"] = True
|
||||
processed_message["processor_message"] = "This message was processed by the dummy processor"
|
||||
|
||||
logger.info(f"Finished processing message. Result: {processed_message}")
|
||||
return processed_message
|
||||
|
||||
|
||||
async def on_response_message_callback(storage: S3Storage):
|
||||
async def on_message(message: AbstractIncomingMessage) -> None:
|
||||
async with message.process(ignore_processed=True):
|
||||
if not message.body:
|
||||
raise ValueError
|
||||
response = json.loads(message.body)
|
||||
logger.info(f"Received {response}")
|
||||
logger.info(f"Message headers: {message.properties.headers}")
|
||||
await message.ack()
|
||||
tenant_id, dossier_id, file_id = itemgetter("tenantId", "dossierId", "fileId")(response)
|
||||
suffix = response["responseFileExtension"]
|
||||
result = storage.get_object(f"{tenant_id}/{dossier_id}/{file_id}.{suffix}")
|
||||
result = json.loads(gzip.decompress(result))
|
||||
logger.info(f"Contents of result on storage: {result}")
|
||||
|
||||
return on_message
|
||||
|
||||
|
||||
def upload_json_and_make_message_body(tenant_id: str):
|
||||
dossier_id, file_id, suffix = "dossier", "file", "json.gz"
|
||||
content = {
|
||||
"numberOfPages": 7,
|
||||
"sectionTexts": "data",
|
||||
}
|
||||
|
||||
object_name = f"{tenant_id}/{dossier_id}/{file_id}.{suffix}"
|
||||
data = gzip.compress(json.dumps(content).encode("utf-8"))
|
||||
|
||||
storage = get_s3_storage_from_settings(settings)
|
||||
if not storage.has_bucket():
|
||||
storage.make_bucket()
|
||||
storage.put_object(object_name, data)
|
||||
|
||||
message_body = {
|
||||
"tenantId": tenant_id,
|
||||
"dossierId": dossier_id,
|
||||
"fileId": file_id,
|
||||
"targetFileExtension": suffix,
|
||||
"responseFileExtension": f"result.{suffix}",
|
||||
}
|
||||
return message_body, storage
|
||||
|
||||
|
||||
async def test_rabbitmq_handler() -> None:
|
||||
tenant_service_url = settings.storage.tenant_server.endpoint
|
||||
|
||||
config = RabbitMQConfig(
|
||||
host=settings.rabbitmq.host,
|
||||
port=settings.rabbitmq.port,
|
||||
username=settings.rabbitmq.username,
|
||||
password=settings.rabbitmq.password,
|
||||
heartbeat=settings.rabbitmq.heartbeat,
|
||||
input_queue_prefix=settings.rabbitmq.service_request_queue_prefix,
|
||||
tenant_event_queue_suffix=settings.rabbitmq.tenant_event_queue_suffix,
|
||||
tenant_exchange_name=settings.rabbitmq.tenant_exchange_name,
|
||||
service_request_exchange_name=settings.rabbitmq.service_request_exchange_name,
|
||||
service_response_exchange_name=settings.rabbitmq.service_response_exchange_name,
|
||||
service_dead_letter_queue_name=settings.rabbitmq.service_dlq_name,
|
||||
queue_expiration_time=settings.rabbitmq.queue_expiration_time,
|
||||
pod_name=settings.kubernetes.pod_name,
|
||||
)
|
||||
|
||||
handler = AsyncQueueManager(config, tenant_service_url, dummy_message_processor)
|
||||
|
||||
await handler.connect()
|
||||
await handler.setup_exchanges()
|
||||
|
||||
tenant_id = "test_tenant"
|
||||
|
||||
# Test tenant creation
|
||||
create_message = {"tenantId": tenant_id}
|
||||
await handler.tenant_exchange.publish(
|
||||
Message(body=json.dumps(create_message).encode()), routing_key="tenant.created"
|
||||
)
|
||||
logger.info(f"Sent create tenant message for {tenant_id}")
|
||||
await asyncio.sleep(0.5) # Wait for queue creation
|
||||
|
||||
# Prepare service request
|
||||
service_request, storage = upload_json_and_make_message_body(tenant_id)
|
||||
|
||||
# Test service request
|
||||
await handler.input_exchange.publish(Message(body=json.dumps(service_request).encode()), routing_key=tenant_id)
|
||||
logger.info(f"Sent service request for {tenant_id}")
|
||||
await asyncio.sleep(5) # Wait for message processing
|
||||
|
||||
# Consume service request
|
||||
response_queue = await handler.channel.declare_queue(name=f"response_queue_{tenant_id}")
|
||||
await response_queue.bind(exchange=handler.output_exchange, routing_key=tenant_id)
|
||||
callback = await on_response_message_callback(storage)
|
||||
await response_queue.consume(callback=callback)
|
||||
|
||||
await asyncio.sleep(5) # Wait for message processing
|
||||
|
||||
# Test tenant deletion
|
||||
delete_message = {"tenantId": tenant_id}
|
||||
await handler.tenant_exchange.publish(
|
||||
Message(body=json.dumps(delete_message).encode()), routing_key="tenant.delete"
|
||||
)
|
||||
logger.info(f"Sent delete tenant message for {tenant_id}")
|
||||
await asyncio.sleep(0.5) # Wait for queue deletion
|
||||
|
||||
await handler.connection.close()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(test_rabbitmq_handler())
|
||||
@ -1,67 +0,0 @@
|
||||
import gzip
|
||||
import json
|
||||
from operator import itemgetter
|
||||
|
||||
from kn_utils.logging import logger
|
||||
|
||||
from pyinfra.config.loader import load_settings, local_pyinfra_root_path
|
||||
from pyinfra.queue.manager import QueueManager
|
||||
from pyinfra.storage.storages.s3 import get_s3_storage_from_settings
|
||||
|
||||
settings = load_settings(local_pyinfra_root_path / "config/")
|
||||
|
||||
|
||||
def upload_json_and_make_message_body():
|
||||
dossier_id, file_id, suffix = "dossier", "file", "json.gz"
|
||||
content = {
|
||||
"numberOfPages": 7,
|
||||
"sectionTexts": "data",
|
||||
}
|
||||
|
||||
object_name = f"{dossier_id}/{file_id}.{suffix}"
|
||||
data = gzip.compress(json.dumps(content).encode("utf-8"))
|
||||
|
||||
storage = get_s3_storage_from_settings(settings)
|
||||
if not storage.has_bucket():
|
||||
storage.make_bucket()
|
||||
storage.put_object(object_name, data)
|
||||
|
||||
message_body = {
|
||||
"dossierId": dossier_id,
|
||||
"fileId": file_id,
|
||||
"targetFileExtension": suffix,
|
||||
"responseFileExtension": f"result.{suffix}",
|
||||
}
|
||||
return message_body
|
||||
|
||||
|
||||
def main():
|
||||
queue_manager = QueueManager(settings)
|
||||
queue_manager.purge_queues()
|
||||
|
||||
message = upload_json_and_make_message_body()
|
||||
|
||||
queue_manager.publish_message_to_input_queue(message)
|
||||
logger.info(f"Put {message} on {settings.rabbitmq.input_queue}.")
|
||||
|
||||
storage = get_s3_storage_from_settings(settings)
|
||||
for method_frame, properties, body in queue_manager.channel.consume(
|
||||
queue=settings.rabbitmq.output_queue, inactivity_timeout=15
|
||||
):
|
||||
if not body:
|
||||
break
|
||||
response = json.loads(body)
|
||||
logger.info(f"Received {response}")
|
||||
logger.info(f"Message headers: {properties.headers}")
|
||||
queue_manager.channel.basic_ack(method_frame.delivery_tag)
|
||||
dossier_id, file_id = itemgetter("dossierId", "fileId")(response)
|
||||
suffix = message["responseFileExtension"]
|
||||
print(f"{dossier_id}/{file_id}.{suffix}")
|
||||
result = storage.get_object(f"{dossier_id}/{file_id}.{suffix}")
|
||||
result = json.loads(gzip.decompress(result))
|
||||
logger.info(f"Contents of result on storage: {result}")
|
||||
queue_manager.stop_consuming()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@ -1,17 +0,0 @@
|
||||
import os
|
||||
import signal
|
||||
import time
|
||||
|
||||
# BE CAREFUL WITH THIS SCRIPT - THIS SIMULATES A SIGTERM FROM KUBERNETES
|
||||
target_pid = int(input("Enter the PID of the target script: "))
|
||||
|
||||
print(f"Sending SIGTERM to PID {target_pid}...")
|
||||
time.sleep(1)
|
||||
|
||||
try:
|
||||
os.kill(target_pid, signal.SIGTERM)
|
||||
print("SIGTERM sent.")
|
||||
except ProcessLookupError:
|
||||
print("Process not found.")
|
||||
except PermissionError:
|
||||
print("Permission denied. Are you trying to signal a process you don't own?")
|
||||
@ -1,39 +0,0 @@
|
||||
#!/bin/bash
|
||||
python_version=$1
|
||||
nexus_user=$2
|
||||
nexus_password=$3
|
||||
|
||||
# cookiecutter https://gitlab.knecon.com/knecon/research/template-python-project.git --checkout master
|
||||
# latest_dir=$(ls -td -- */ | head -n 1) # should be the dir cookiecutter just created
|
||||
|
||||
# cd $latest_dir
|
||||
|
||||
pyenv install $python_version
|
||||
pyenv local $python_version
|
||||
pyenv shell $python_version
|
||||
|
||||
# install poetry globally (PREFERRED), only need to install it once
|
||||
# curl -sSL https://install.python-poetry.org | python3 -
|
||||
|
||||
# remember to update poetry once in a while
|
||||
poetry self update
|
||||
|
||||
# install poetry in current python environment, can lead to multiple instances of poetry being installed on one system (DISPREFERRED)
|
||||
# pip install --upgrade pip
|
||||
# pip install poetry
|
||||
|
||||
poetry config virtualenvs.in-project true
|
||||
poetry config installer.max-workers 10
|
||||
poetry config repositories.pypi-proxy "https://nexus.knecon.com/repository/pypi-proxy/simple"
|
||||
poetry config http-basic.pypi-proxy ${nexus_user} ${nexus_password}
|
||||
poetry config repositories.nexus https://nexus.knecon.com/repository/python/simple
|
||||
poetry config http-basic.nexus ${nexus_user} ${nexus_password}
|
||||
|
||||
poetry env use $(pyenv which python)
|
||||
poetry install --with=dev
|
||||
poetry update
|
||||
|
||||
source .venv/bin/activate
|
||||
|
||||
pre-commit install
|
||||
pre-commit autoupdate
|
||||
@ -1,18 +0,0 @@
|
||||
import time
|
||||
|
||||
from pyinfra.config.loader import load_settings, parse_settings_path
|
||||
from pyinfra.examples import start_standard_queue_consumer
|
||||
from pyinfra.queue.callback import make_download_process_upload_callback
|
||||
|
||||
|
||||
def processor_mock(_data: dict, _message: dict) -> dict:
|
||||
time.sleep(5)
|
||||
return {"result1": "result1"}
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
arguments = parse_settings_path()
|
||||
settings = load_settings(arguments)
|
||||
|
||||
callback = make_download_process_upload_callback(processor_mock, settings)
|
||||
start_standard_queue_consumer(callback, settings)
|
||||
13
setup.py
Executable file
13
setup.py
Executable file
@ -0,0 +1,13 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
from distutils.core import setup
|
||||
|
||||
setup(
|
||||
name="pyinfra",
|
||||
version="0.0.1",
|
||||
description="",
|
||||
author="",
|
||||
author_email="",
|
||||
url="",
|
||||
packages=["pyinfra"],
|
||||
)
|
||||
4
sonar-project.properties
Normal file
4
sonar-project.properties
Normal file
@ -0,0 +1,4 @@
|
||||
sonar.exclusions=bamboo-specs/**, build_venv/**
|
||||
sonar.c.file.suffixes=-
|
||||
sonar.cpp.file.suffixes=-
|
||||
sonar.objc.file.suffixes=-
|
||||
@ -1,48 +1,19 @@
|
||||
import json
|
||||
|
||||
import pytest
|
||||
from pyinfra.config import get_config, Config
|
||||
import os
|
||||
|
||||
from pyinfra.config.loader import load_settings, local_pyinfra_root_path
|
||||
from pyinfra.queue.manager import QueueManager
|
||||
from pyinfra.storage.connection import get_storage
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def settings():
|
||||
return load_settings(local_pyinfra_root_path / "config/")
|
||||
|
||||
|
||||
@pytest.fixture(scope="class")
|
||||
def storage(storage_backend, settings):
|
||||
settings.storage.backend = storage_backend
|
||||
|
||||
storage = get_storage(settings)
|
||||
storage.make_bucket()
|
||||
|
||||
yield storage
|
||||
storage.clear_bucket()
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def queue_manager(settings):
|
||||
settings.rabbitmq_heartbeat = 10
|
||||
settings.connection_sleep = 5
|
||||
settings.rabbitmq.max_retries = 3
|
||||
settings.rabbitmq.max_delay = 10
|
||||
queue_manager = QueueManager(settings)
|
||||
yield queue_manager
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def input_message():
|
||||
return json.dumps(
|
||||
{
|
||||
"targetFilePath": "test/target.json.gz",
|
||||
"responseFilePath": "test/response.json.gz",
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def stop_message():
|
||||
return "STOP"
|
||||
@pytest.fixture(params=["aws", "azure"])
|
||||
def storage_config(request) -> Config:
|
||||
if request.param == "aws":
|
||||
os.environ["STORAGE_BACKEND"] = "s3"
|
||||
os.environ["STORAGE_BUCKET_NAME"] = "pyinfra-test-bucket"
|
||||
os.environ["STORAGE_ENDPOINT"] = "https://s3.amazonaws.com"
|
||||
os.environ["STORAGE_KEY"] = "AKIA4QVP6D4LCDAGYGN2"
|
||||
os.environ["STORAGE_SECRET"] = "8N6H1TUHTsbvW2qMAm7zZlJ63hMqjcXAsdN7TYED"
|
||||
os.environ["STORAGE_REGION"] = "eu-west-1"
|
||||
else:
|
||||
os.environ["STORAGE_BACKEND"] = "azure"
|
||||
os.environ["STORAGE_AZURECONTAINERNAME"] = "pyinfra-test-bucket"
|
||||
os.environ["STORAGE_AZURECONNECTIONSTRING"] = "DefaultEndpointsProtocol=https;AccountName=iqserdevelopment;AccountKey=4imAbV9PYXaztSOMpIyAClg88bAZCXuXMGJG0GA1eIBpdh2PlnFGoRBnKqLy2YZUSTmZ3wJfC7tzfHtuC6FEhQ==;EndpointSuffix=core.windows.net"
|
||||
|
||||
return get_config()
|
||||
|
||||
@ -1,6 +0,0 @@
|
||||
outs:
|
||||
- md5: 75cc98b7c8fcf782a7d4941594e6bc12.dir
|
||||
size: 134913
|
||||
nfiles: 9
|
||||
hash: md5
|
||||
path: data
|
||||
@ -1,41 +0,0 @@
|
||||
version: '3.8'
|
||||
services:
|
||||
minio:
|
||||
image: minio/minio:latest
|
||||
container_name: minio
|
||||
ports:
|
||||
- "9000:9000"
|
||||
environment:
|
||||
- MINIO_ROOT_PASSWORD=password
|
||||
- MINIO_ROOT_USER=root
|
||||
volumes:
|
||||
- /tmp/data/minio_store:/data
|
||||
command: server /data
|
||||
network_mode: "bridge"
|
||||
extra_hosts:
|
||||
- "host.docker.internal:host-gateway"
|
||||
rabbitmq:
|
||||
image: docker.io/bitnami/rabbitmq:latest
|
||||
container_name: rabbitmq
|
||||
ports:
|
||||
# - '4369:4369'
|
||||
# - '5551:5551'
|
||||
# - '5552:5552'
|
||||
- '5672:5672'
|
||||
- '15672:15672'
|
||||
# - '25672:25672'
|
||||
environment:
|
||||
- RABBITMQ_SECURE_PASSWORD=yes
|
||||
- RABBITMQ_VM_MEMORY_HIGH_WATERMARK=100%
|
||||
- RABBITMQ_DISK_FREE_ABSOLUTE_LIMIT=20Gi
|
||||
- RABBITMQ_MANAGEMENT_ALLOW_WEB_ACCESS=true
|
||||
network_mode: "bridge"
|
||||
volumes:
|
||||
- /tmp/bitnami/rabbitmq/.rabbitmq/:/data/bitnami
|
||||
healthcheck:
|
||||
test: [ "CMD", "curl", "-f", "http://localhost:15672" ]
|
||||
interval: 30s
|
||||
timeout: 10s
|
||||
retries: 5
|
||||
extra_hosts:
|
||||
- "host.docker.internal:host-gateway"
|
||||
@ -1,41 +0,0 @@
|
||||
from time import sleep
|
||||
|
||||
import pytest
|
||||
|
||||
from pyinfra.utils.opentelemetry import get_exporter, instrument_pika, setup_trace
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def exporter(settings):
|
||||
settings.tracing.opentelemetry.exporter = "json"
|
||||
return get_exporter(settings)
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def setup_test_trace(settings, exporter, tracing_type):
|
||||
settings.tracing.type = tracing_type
|
||||
setup_trace(settings, exporter=exporter)
|
||||
|
||||
|
||||
class TestOpenTelemetry:
|
||||
@pytest.mark.xfail(
|
||||
reason="Azure Monitor requires a connection string. Therefore the test is allowed to fail in this case."
|
||||
)
|
||||
@pytest.mark.parametrize("tracing_type", ["opentelemetry", "azure_monitor"])
|
||||
def test_queue_messages_are_traced(self, queue_manager, input_message, stop_message, settings, exporter):
|
||||
instrument_pika()
|
||||
|
||||
queue_manager.purge_queues()
|
||||
queue_manager.publish_message_to_input_queue(input_message)
|
||||
queue_manager.publish_message_to_input_queue(stop_message)
|
||||
|
||||
def callback(_):
|
||||
sleep(2)
|
||||
return {"flat": "earth"}
|
||||
|
||||
queue_manager.start_consuming(callback)
|
||||
|
||||
for exported_trace in exporter.traces:
|
||||
assert (
|
||||
exported_trace["resource"]["attributes"]["service.name"] == settings.tracing.opentelemetry.service_name
|
||||
)
|
||||
@ -1,55 +0,0 @@
|
||||
import re
|
||||
from time import sleep
|
||||
|
||||
import pytest
|
||||
import requests
|
||||
from fastapi import FastAPI
|
||||
|
||||
from pyinfra.webserver.prometheus import (
|
||||
add_prometheus_endpoint,
|
||||
make_prometheus_processing_time_decorator_from_settings,
|
||||
)
|
||||
from pyinfra.webserver.utils import create_webserver_thread_from_settings
|
||||
|
||||
|
||||
@pytest.fixture(scope="class")
|
||||
def app_with_prometheus_endpoint(settings):
|
||||
app = FastAPI()
|
||||
app = add_prometheus_endpoint(app)
|
||||
thread = create_webserver_thread_from_settings(app, settings)
|
||||
thread.daemon = True
|
||||
thread.start()
|
||||
sleep(1)
|
||||
yield
|
||||
thread.join(timeout=1)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def monitored_function(settings):
|
||||
@make_prometheus_processing_time_decorator_from_settings(settings)
|
||||
def process(*args, **kwargs):
|
||||
sleep(0.5)
|
||||
|
||||
return process
|
||||
|
||||
|
||||
class TestPrometheusMonitor:
|
||||
def test_prometheus_endpoint_is_available(self, app_with_prometheus_endpoint, settings):
|
||||
resp = requests.get(f"http://{settings.webserver.host}:{settings.webserver.port}/prometheus")
|
||||
assert resp.status_code == 200
|
||||
|
||||
def test_processing_with_a_monitored_fn_increases_parameter_counter(
|
||||
self, app_with_prometheus_endpoint, monitored_function, settings
|
||||
):
|
||||
pattern = re.compile(rf".*{settings.metrics.prometheus.prefix}_processing_time_count (\d\.\d).*")
|
||||
|
||||
resp = requests.get(f"http://{settings.webserver.host}:{settings.webserver.port}/prometheus")
|
||||
assert pattern.search(resp.text).group(1) == "0.0"
|
||||
|
||||
monitored_function()
|
||||
resp = requests.get(f"http://{settings.webserver.host}:{settings.webserver.port}/prometheus")
|
||||
assert pattern.search(resp.text).group(1) == "1.0"
|
||||
|
||||
monitored_function()
|
||||
resp = requests.get(f"http://{settings.webserver.host}:{settings.webserver.port}/prometheus")
|
||||
assert pattern.search(resp.text).group(1) == "2.0"
|
||||
@ -1,90 +0,0 @@
|
||||
import json
|
||||
from sys import stdout
|
||||
from time import sleep
|
||||
|
||||
import pika
|
||||
from kn_utils.logging import logger
|
||||
|
||||
logger.remove()
|
||||
logger.add(sink=stdout, level="DEBUG")
|
||||
|
||||
|
||||
def make_callback(process_time):
|
||||
def callback(x):
|
||||
sleep(process_time)
|
||||
return {"status": "success"}
|
||||
|
||||
return callback
|
||||
|
||||
|
||||
def file_not_found_callback(x):
|
||||
raise FileNotFoundError("File not found")
|
||||
|
||||
|
||||
class TestQueueManager:
|
||||
def test_not_available_file_leads_to_message_rejection_without_crashing(
|
||||
self, queue_manager, input_message, stop_message
|
||||
):
|
||||
queue_manager.purge_queues()
|
||||
|
||||
queue_manager.publish_message_to_input_queue(input_message)
|
||||
queue_manager.publish_message_to_input_queue(stop_message)
|
||||
|
||||
queue_manager.start_consuming(file_not_found_callback)
|
||||
|
||||
def test_processing_of_several_messages(self, queue_manager, input_message, stop_message):
|
||||
queue_manager.purge_queues()
|
||||
|
||||
for _ in range(2):
|
||||
queue_manager.publish_message_to_input_queue(input_message)
|
||||
|
||||
queue_manager.publish_message_to_input_queue(stop_message)
|
||||
|
||||
callback = make_callback(1)
|
||||
queue_manager.start_consuming(callback)
|
||||
|
||||
for _ in range(2):
|
||||
response = queue_manager.get_message_from_output_queue()
|
||||
assert response is not None
|
||||
assert json.loads(response[2].decode()) == {"status": "success"}
|
||||
|
||||
def test_all_headers_beginning_with_x_are_forwarded(self, queue_manager, input_message, stop_message):
|
||||
queue_manager.purge_queues()
|
||||
|
||||
properties = pika.BasicProperties(
|
||||
headers={
|
||||
"X-TENANT-ID": "redaction",
|
||||
"X-OTHER-HEADER": "other-header-value",
|
||||
"x-tenant_id": "tenant-id-value",
|
||||
"x_should_not_be_forwarded": "should-not-be-forwarded-value",
|
||||
}
|
||||
)
|
||||
|
||||
queue_manager.publish_message_to_input_queue(input_message, properties=properties)
|
||||
queue_manager.publish_message_to_input_queue(stop_message)
|
||||
|
||||
callback = make_callback(0.2)
|
||||
queue_manager.start_consuming(callback)
|
||||
|
||||
response = queue_manager.get_message_from_output_queue()
|
||||
|
||||
assert json.loads(response[2].decode()) == {"status": "success"}
|
||||
|
||||
assert response[1].headers["X-TENANT-ID"] == "redaction"
|
||||
assert response[1].headers["X-OTHER-HEADER"] == "other-header-value"
|
||||
assert response[1].headers["x-tenant_id"] == "tenant-id-value"
|
||||
|
||||
assert "x_should_not_be_forwarded" not in response[1].headers
|
||||
|
||||
def test_message_processing_does_not_block_heartbeat(self, queue_manager, input_message, stop_message):
|
||||
queue_manager.purge_queues()
|
||||
|
||||
queue_manager.publish_message_to_input_queue(input_message)
|
||||
queue_manager.publish_message_to_input_queue(stop_message)
|
||||
|
||||
callback = make_callback(15)
|
||||
queue_manager.start_consuming(callback)
|
||||
|
||||
response = queue_manager.get_message_from_output_queue()
|
||||
|
||||
assert json.loads(response[2].decode()) == {"status": "success"}
|
||||
@ -1,166 +0,0 @@
|
||||
import gzip
|
||||
import json
|
||||
from time import sleep
|
||||
|
||||
import pytest
|
||||
from fastapi import FastAPI
|
||||
|
||||
from pyinfra.storage.connection import get_storage_for_tenant
|
||||
from pyinfra.storage.utils import (
|
||||
download_data_bytes_as_specified_in_message,
|
||||
upload_data_as_specified_in_message,
|
||||
)
|
||||
from pyinfra.utils.cipher import encrypt
|
||||
from pyinfra.webserver.utils import create_webserver_thread
|
||||
|
||||
|
||||
@pytest.mark.parametrize("storage_backend", ["azure", "s3"], scope="class")
|
||||
class TestStorage:
|
||||
def test_clearing_bucket_yields_empty_bucket(self, storage):
|
||||
storage.clear_bucket()
|
||||
data_received = storage.get_all_objects()
|
||||
assert not {*data_received}
|
||||
|
||||
def test_getting_object_put_in_bucket_is_object(self, storage):
|
||||
storage.clear_bucket()
|
||||
storage.put_object("file", b"content")
|
||||
data_received = storage.get_object("file")
|
||||
assert b"content" == data_received
|
||||
|
||||
def test_object_put_in_bucket_exists_on_storage(self, storage):
|
||||
storage.clear_bucket()
|
||||
storage.put_object("file", b"content")
|
||||
assert storage.exists("file")
|
||||
|
||||
def test_getting_nested_object_put_in_bucket_is_nested_object(self, storage):
|
||||
storage.clear_bucket()
|
||||
storage.put_object("folder/file", b"content")
|
||||
data_received = storage.get_object("folder/file")
|
||||
assert b"content" == data_received
|
||||
|
||||
def test_getting_objects_put_in_bucket_are_objects(self, storage):
|
||||
storage.clear_bucket()
|
||||
storage.put_object("file1", b"content 1")
|
||||
storage.put_object("folder/file2", b"content 2")
|
||||
data_received = storage.get_all_objects()
|
||||
assert {b"content 1", b"content 2"} == {*data_received}
|
||||
|
||||
def test_make_bucket_produces_bucket(self, storage):
|
||||
storage.clear_bucket()
|
||||
storage.make_bucket()
|
||||
assert storage.has_bucket()
|
||||
|
||||
def test_listing_bucket_files_yields_all_files_in_bucket(self, storage):
|
||||
storage.clear_bucket()
|
||||
storage.put_object("file1", b"content 1")
|
||||
storage.put_object("file2", b"content 2")
|
||||
full_names_received = storage.get_all_object_names()
|
||||
assert {(storage.bucket, "file1"), (storage.bucket, "file2")} == {*full_names_received}
|
||||
|
||||
def test_data_loading_failure_raised_if_object_not_present(self, storage):
|
||||
storage.clear_bucket()
|
||||
with pytest.raises(Exception):
|
||||
storage.get_object("folder/file")
|
||||
|
||||
|
||||
@pytest.fixture(scope="class")
|
||||
def tenant_server_mock(settings, tenant_server_host, tenant_server_port):
|
||||
app = FastAPI()
|
||||
|
||||
@app.get("/azure_tenant")
|
||||
def get_azure_storage_info():
|
||||
return {
|
||||
"azureStorageConnection": {
|
||||
"connectionString": encrypt(
|
||||
settings.storage.tenant_server.public_key, settings.storage.azure.connection_string
|
||||
),
|
||||
"containerName": settings.storage.azure.container,
|
||||
}
|
||||
}
|
||||
|
||||
@app.get("/s3_tenant")
|
||||
def get_s3_storage_info():
|
||||
return {
|
||||
"s3StorageConnection": {
|
||||
"endpoint": settings.storage.s3.endpoint,
|
||||
"key": settings.storage.s3.key,
|
||||
"secret": encrypt(settings.storage.tenant_server.public_key, settings.storage.s3.secret),
|
||||
"region": settings.storage.s3.region,
|
||||
"bucketName": settings.storage.s3.bucket,
|
||||
}
|
||||
}
|
||||
|
||||
thread = create_webserver_thread(app, tenant_server_port, tenant_server_host)
|
||||
thread.daemon = True
|
||||
thread.start()
|
||||
sleep(1)
|
||||
yield
|
||||
thread.join(timeout=1)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("tenant_id", ["azure_tenant", "s3_tenant"], scope="class")
|
||||
@pytest.mark.parametrize("tenant_server_host", ["localhost"], scope="class")
|
||||
@pytest.mark.parametrize("tenant_server_port", [8000], scope="class")
|
||||
class TestMultiTenantStorage:
|
||||
def test_storage_connection_from_tenant_id(
|
||||
self, tenant_id, tenant_server_mock, settings, tenant_server_host, tenant_server_port
|
||||
):
|
||||
settings["storage"]["tenant_server"]["endpoint"] = f"http://{tenant_server_host}:{tenant_server_port}"
|
||||
storage = get_storage_for_tenant(
|
||||
tenant_id,
|
||||
settings["storage"]["tenant_server"]["endpoint"],
|
||||
settings["storage"]["tenant_server"]["public_key"],
|
||||
)
|
||||
|
||||
storage.put_object("file", b"content")
|
||||
data_received = storage.get_object("file")
|
||||
|
||||
assert b"content" == data_received
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def payload(payload_type):
|
||||
if payload_type == "target_response_file_path":
|
||||
return {
|
||||
"targetFilePath": "test/file.target.json.gz",
|
||||
"responseFilePath": "test/file.response.json.gz",
|
||||
}
|
||||
elif payload_type == "dossier_id_file_id":
|
||||
return {
|
||||
"dossierId": "test",
|
||||
"fileId": "file",
|
||||
"targetFileExtension": "target.json.gz",
|
||||
"responseFileExtension": "response.json.gz",
|
||||
}
|
||||
elif payload_type == "target_file_dict":
|
||||
return {
|
||||
"targetFilePath": {"file_1": "test/file.target.json.gz", "file_2": "test/file.target.json.gz"},
|
||||
"responseFilePath": "test/file.response.json.gz",
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"payload_type",
|
||||
[
|
||||
"target_response_file_path",
|
||||
"dossier_id_file_id",
|
||||
"target_file_dict",
|
||||
],
|
||||
scope="class",
|
||||
)
|
||||
@pytest.mark.parametrize("storage_backend", ["azure", "s3"], scope="class")
|
||||
class TestDownloadAndUploadFromMessage:
|
||||
def test_download_and_upload_from_message(self, storage, payload, payload_type):
|
||||
storage.clear_bucket()
|
||||
|
||||
result = {"process_result": "success"}
|
||||
storage_data = {**payload, "data": result}
|
||||
packed_data = gzip.compress(json.dumps(storage_data).encode())
|
||||
|
||||
storage.put_object("test/file.target.json.gz", packed_data)
|
||||
|
||||
_ = download_data_bytes_as_specified_in_message(storage, payload)
|
||||
upload_data_as_specified_in_message(storage, payload, result)
|
||||
data = json.loads(gzip.decompress(storage.get_object("test/file.response.json.gz")).decode())
|
||||
|
||||
assert data == storage_data
|
||||
5
tests/test_storage.py
Normal file
5
tests/test_storage.py
Normal file
@ -0,0 +1,5 @@
|
||||
from pyinfra.storage import get_storage
|
||||
|
||||
def test_storage(storage_config) -> None:
|
||||
storage = get_storage(storage_config)
|
||||
assert storage.has_bucket(storage_config.storage_bucket)
|
||||
@ -1,29 +0,0 @@
|
||||
import pytest
|
||||
|
||||
from pyinfra.utils.cipher import decrypt, encrypt
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def ciphertext():
|
||||
return "AAAADBRzag4/aAE2+rSekyI5phVZ1e0wwSaRkGQTLftPyVvq8vLYZzwxW48Wozc3/w=="
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def plaintext():
|
||||
return "connectzionString"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def public_key():
|
||||
return "redaction"
|
||||
|
||||
|
||||
class TestDecryption:
|
||||
def test_decrypt_ciphertext(self, public_key, ciphertext, plaintext):
|
||||
result = decrypt(public_key, ciphertext)
|
||||
assert result == plaintext
|
||||
|
||||
def test_encrypt_plaintext(self, public_key, plaintext):
|
||||
ciphertext = encrypt(public_key, plaintext)
|
||||
result = decrypt(public_key, ciphertext)
|
||||
assert plaintext == result
|
||||
@ -1,55 +0,0 @@
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
from dynaconf import Validator
|
||||
|
||||
from pyinfra.config.loader import load_settings, local_pyinfra_root_path, normalize_to_settings_files
|
||||
from pyinfra.config.validators import webserver_validators
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def test_validators():
|
||||
return [
|
||||
Validator("test.value.int", must_exist=True, is_type_of=int),
|
||||
Validator("test.value.str", must_exist=True, is_type_of=str),
|
||||
]
|
||||
|
||||
|
||||
class TestConfig:
|
||||
def test_config_validation(self):
|
||||
os.environ["WEBSERVER__HOST"] = "localhost"
|
||||
os.environ["WEBSERVER__PORT"] = "8080"
|
||||
|
||||
validators = webserver_validators
|
||||
|
||||
test_settings = load_settings(root_path=local_pyinfra_root_path, validators=validators)
|
||||
|
||||
assert test_settings.webserver.host == "localhost"
|
||||
|
||||
def test_env_into_correct_type_conversion(self, test_validators):
|
||||
os.environ["TEST__VALUE__INT"] = "1"
|
||||
os.environ["TEST__VALUE__STR"] = "test"
|
||||
|
||||
test_settings = load_settings(root_path=local_pyinfra_root_path, validators=test_validators)
|
||||
|
||||
assert test_settings.test.value.int == 1
|
||||
assert test_settings.test.value.str == "test"
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"settings_path,expected_file_paths",
|
||||
[
|
||||
(None, []),
|
||||
("config", [f"{local_pyinfra_root_path}/config/settings.toml"]),
|
||||
("config/settings.toml", [f"{local_pyinfra_root_path}/config/settings.toml"]),
|
||||
(f"{local_pyinfra_root_path}/config", [f"{local_pyinfra_root_path}/config/settings.toml"]),
|
||||
],
|
||||
)
|
||||
def test_normalize_settings_files(self, settings_path, expected_file_paths):
|
||||
files = normalize_to_settings_files(settings_path, local_pyinfra_root_path)
|
||||
print(files)
|
||||
|
||||
assert len(files) == len(expected_file_paths)
|
||||
|
||||
for path, expected in zip(files, expected_file_paths):
|
||||
assert path == Path(expected).absolute()
|
||||
@ -1,19 +0,0 @@
|
||||
import pytest
|
||||
from kn_utils.logging import logger
|
||||
|
||||
|
||||
def test_necessary_log_levels_are_supported_by_kn_utils():
|
||||
logger.setLevel("TRACE")
|
||||
|
||||
logger.trace("trace")
|
||||
logger.debug("debug")
|
||||
logger.info("info")
|
||||
logger.warning("warning")
|
||||
logger.critical("critical")
|
||||
logger.exception("exception", exc_info="this is an exception")
|
||||
logger.error("error", exc_info="this is an error")
|
||||
|
||||
|
||||
def test_setlevel_warn():
|
||||
logger.setLevel("WARN")
|
||||
logger.warning("warn")
|
||||
@ -1,83 +0,0 @@
|
||||
import json
|
||||
|
||||
import pytest
|
||||
from unittest.mock import patch
|
||||
from pyinfra.storage.utils import (
|
||||
download_data_bytes_as_specified_in_message,
|
||||
upload_data_as_specified_in_message,
|
||||
DownloadedData,
|
||||
)
|
||||
from pyinfra.storage.storages.storage import Storage
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_storage():
|
||||
with patch("pyinfra.storage.utils.Storage") as MockStorage:
|
||||
yield MockStorage()
|
||||
|
||||
|
||||
@pytest.fixture(
|
||||
params=[
|
||||
{
|
||||
"raw_payload": {
|
||||
"tenantId": "tenant1",
|
||||
"dossierId": "dossier1",
|
||||
"fileId": "file1",
|
||||
"targetFileExtension": "txt",
|
||||
"responseFileExtension": "json",
|
||||
},
|
||||
"expected_result": {
|
||||
"data": b'{"key": "value"}',
|
||||
"file_path": "tenant1/dossier1/file1.txt"
|
||||
}
|
||||
},
|
||||
{
|
||||
"raw_payload": {
|
||||
"targetFilePath": "some/path/to/file.txt.gz",
|
||||
"responseFilePath": "some/path/to/file.json"
|
||||
},
|
||||
"expected_result": {
|
||||
"data": b'{"key": "value"}',
|
||||
"file_path": "some/path/to/file.txt.gz"
|
||||
}
|
||||
},
|
||||
{
|
||||
"raw_payload": {
|
||||
"targetFilePath": {
|
||||
"file1": "some/path/to/file1.txt.gz",
|
||||
"file2": "some/path/to/file2.txt.gz"
|
||||
},
|
||||
"responseFilePath": "some/path/to/file.json"
|
||||
},
|
||||
"expected_result": {
|
||||
"file1": {
|
||||
"data": b'{"key": "value"}',
|
||||
"file_path": "some/path/to/file1.txt.gz"
|
||||
},
|
||||
"file2": {
|
||||
"data": b'{"key": "value"}',
|
||||
"file_path": "some/path/to/file2.txt.gz"
|
||||
}
|
||||
}
|
||||
},
|
||||
]
|
||||
)
|
||||
def payload_and_expected_result(request):
|
||||
return request.param
|
||||
|
||||
def test_download_data_bytes_as_specified_in_message(mock_storage, payload_and_expected_result):
|
||||
raw_payload = payload_and_expected_result["raw_payload"]
|
||||
expected_result = payload_and_expected_result["expected_result"]
|
||||
mock_storage.get_object.return_value = b'{"key": "value"}'
|
||||
|
||||
result = download_data_bytes_as_specified_in_message(mock_storage, raw_payload)
|
||||
|
||||
assert isinstance(result, dict)
|
||||
assert result == expected_result
|
||||
mock_storage.get_object.assert_called()
|
||||
|
||||
def test_upload_data_as_specified_in_message(mock_storage, payload_and_expected_result):
|
||||
raw_payload = payload_and_expected_result["raw_payload"]
|
||||
data = {"key": "value"}
|
||||
upload_data_as_specified_in_message(mock_storage, raw_payload, data)
|
||||
mock_storage.put_object.assert_called_once()
|
||||