Compare commits

...

62 Commits

Author SHA1 Message Date
Julius Unverfehrt
3ef4246d1e chore: fuzzy pin kn-utils to allow for future updates 2025-01-22 12:36:38 +01:00
Julius Unverfehrt
841c492639 Merge branch 'chore/RES-871-update-callback' into 'master'
feat:BREAKING CHANGE: download callback now forwards all files as bytes

See merge request knecon/research/pyinfra!108
2025-01-16 11:11:59 +01:00
Julius Unverfehrt
ead069d3a7 chore: adjust docstrings 2025-01-16 10:35:06 +01:00
Julius Unverfehrt
044ea6cf0a feat: streamline download to always include the filename of the downloaded file 2025-01-16 10:29:50 +01:00
Julius Unverfehrt
ff7547e2c6 fix: remove faulty import 2025-01-16 10:29:50 +01:00
Julius Unverfehrt
fbf79ef758 chore: regenerate BOM 2025-01-16 10:29:50 +01:00
Julius Unverfehrt
f382887d40 chore: seek and destroy proto in code 2025-01-16 10:29:50 +01:00
Julius Unverfehrt
5c4400aa8b feat:BREAKING CHANGE: download callback now forwards all files as bytes 2025-01-16 10:29:46 +01:00
Jonathan Kössler
5ce66f18a0 Merge branch 'bugfix/RED-10722' into 'master'
fix: dlq init

See merge request knecon/research/pyinfra!109
2025-01-15 10:56:12 +01:00
Jonathan Kössler
ea0c55930a chore: remove test nack 2025-01-15 10:00:50 +01:00
Jonathan Kössler
87f57e2244 fix: dlq init 2025-01-14 16:39:47 +01:00
Jonathan Kössler
3fb8c4e641 fix: do not use groups for packages 2024-12-18 16:33:35 +01:00
Jonathan Kössler
e23f63acf0 Merge branch 'chore/nexus-package-registry' into 'master'
RES-914: move package registry to nexus

See merge request knecon/research/pyinfra!106
2024-11-20 10:02:52 +01:00
Jonathan Kössler
d3fecc518e chore: move integration tests to own subfolder 2024-11-18 17:31:15 +01:00
Jonathan Kössler
341500d463 chore: set lower bound for opentelemetry dependencies 2024-11-18 17:28:11 +01:00
Jonathan Kössler
e002f77fd5 Revert "chore: update opentelemetry for proto v5 support"
This reverts commit 3c6d8f2dcc73b17f329f9cecb8d4d301f848dc1e.
2024-11-18 17:19:37 +01:00
Jonathan Kössler
3c6d8f2dcc chore: update opentelemetry for proto v5 support 2024-11-18 15:14:34 +01:00
Jonathan Kössler
f6d6ba40bb chore: add pytest-cov 2024-11-18 13:57:39 +01:00
Jonathan Kössler
6a0bbad108 ops: update CI 2024-11-18 13:53:11 +01:00
Jonathan Kössler
527a671a75 feat: move package registry to nexus 2024-11-18 13:49:48 +01:00
Jonathan Kössler
cf91189728 Merge branch 'feature/RED-10441' into 'master'
RED-10441: separate queue and webserver shutdown

See merge request knecon/research/pyinfra!105
2024-11-13 17:17:13 +01:00
Jonathan Kössler
61a6d0eeed feat: separate queue and webserver shutdown 2024-11-13 17:02:21 +01:00
Jonathan Kössler
bc0b355ff9 Merge branch 'feature/RED-10441' into 'master'
RED-10441: ensure queue manager shutdown

See merge request knecon/research/pyinfra!104
2024-11-13 16:34:25 +01:00
Jonathan Kössler
235e27b74c chore: bump version 2024-11-13 16:31:48 +01:00
Jonathan Kössler
1540c2894e feat: ensure shutdown of queue manager 2024-11-13 16:30:18 +01:00
Jonathan Kössler
9b60594ce1 Merge branch 'feature/RED-10441' into 'master'
RED-10441: Fix graceful shutdown

See merge request knecon/research/pyinfra!103
2024-11-13 14:48:34 +01:00
Jonathan Kössler
3d3c76b466 chore: bump version 2024-11-13 13:55:15 +01:00
Jonathan Kössler
9d4ec84b49 fix: use signals for graceful shutdown 2024-11-13 13:54:41 +01:00
Jonathan Kössler
8891249d7a Merge branch 'feature/RED-10441' into 'master'
RED-10441: fix abandoned queues

See merge request knecon/research/pyinfra!102
2024-11-13 09:35:36 +01:00
Jonathan Kössler
e51e5c33eb chore: cleanup 2024-11-12 17:24:57 +01:00
Jonathan Kössler
04c90533b6 refactor: fetch active tenants before start 2024-11-12 17:11:33 +01:00
Jonathan Kössler
86af05c12c feat: add logger to retry 2024-11-12 16:50:23 +01:00
Jonathan Kössler
c6e336cb35 refactor: tenant queues init 2024-11-12 15:55:11 +01:00
Jonathan Kössler
bf6f95f3e0 feat: exit on ClientResponseError 2024-11-12 15:32:11 +01:00
Jonathan Kössler
ed2bd1ec86 refactor: raise error if tenant service is not available 2024-11-12 13:30:21 +01:00
Julius Unverfehrt
9906f68e0a chore: bump versions to enable package rebuild (current package has the wrong hash due to backup issues) 2024-11-11 12:47:27 +01:00
Julius Unverfehrt
0af648d66c fix: rebuild since mia and update rebuild kn_utils 2024-11-08 13:52:08 +01:00
Jonathan Kössler
46dc1fdce4 Merge branch 'feature/RES-809' into 'master'
RES-809: update kn_utils

See merge request knecon/research/pyinfra!101
2024-10-23 18:01:25 +02:00
Jonathan Kössler
bd2f0b9b9a feat: switch out tenacity retry with kn_utils 2024-10-23 16:06:06 +02:00
Jonathan Kössler
131afd7d3e chore: update kn_utils 2024-10-23 16:04:08 +02:00
Jonathan Kössler
98532c60ed Merge branch 'feature/RES-858-fix-graceful-shutdown' into 'master'
RES-858: fix graceful shutdown for unexpected broker disconnects

See merge request knecon/research/pyinfra!100
2024-09-30 09:54:25 +02:00
Jonathan Kössler
45377ba172 feat: improve on close callback and simplify exception handling 2024-09-27 17:11:10 +02:00
Jonathan Kössler
f855224e29 feat: add on close callback 2024-09-27 10:00:41 +02:00
Jonathan Kössler
541219177f feat: add error handling to shutdown logic 2024-09-26 12:28:55 +02:00
Jonathan Kössler
4119a7d7d7 chore: bump version 2024-09-26 11:05:12 +02:00
Jonathan Kössler
e2edfa7260 fix: simplify webserver shutdown 2024-09-26 10:33:05 +02:00
Jonathan Kössler
b70b16c541 Merge branch 'feature/RES-856-test-proto-format' into 'master'
RES-856: add type tests for proto format

See merge request knecon/research/pyinfra!99
2024-09-26 10:07:29 +02:00
Jonathan Kössler
e8d9326e48 chore: rewrite lock and bump version 2024-09-26 09:45:42 +02:00
Jonathan Kössler
9669152e14 Merge branch 'master' into feature/RES-856-test-proto-format 2024-09-26 09:39:28 +02:00
Jonathan Kössler
ed3f8088e1 Merge branch 'feature/RES-844-fix-tracing' into 'master'
RES-844: fix opentelemetry tracing

See merge request knecon/research/pyinfra!98
2024-09-26 09:13:52 +02:00
Jonathan Kössler
66eaa9a748 feat: set range for protobuf version 2024-09-25 14:16:40 +02:00
Jonathan Kössler
3a04359320 chore: bump pyinfra version 2024-09-25 11:59:52 +02:00
Jonathan Kössler
b46fcbd977 feat: add AioPikaInstrumentor 2024-09-25 11:58:51 +02:00
Jonathan Kössler
e75df42bec feat: skip keys in int conversion 2024-09-25 11:07:20 +02:00
Jonathan Kössler
3bab86fe83 chore: update test files 2024-09-24 11:59:08 +02:00
Jonathan Kössler
c5d53b8665 feat: add file comparison 2024-09-24 11:57:33 +02:00
Jonathan Kössler
09d39930e7 chore: cleanup test 2024-09-23 16:43:59 +02:00
Jonathan Kössler
a81f1bf31a chore: update protobuf to 25.5 2024-09-23 16:41:57 +02:00
Francisco Schulz
0783e95d22 Merge branch 'RED-10017-investigate-crashing-py-services-when-upload-large-number-of-files' into 'master'
fix: add semaphore to AsyncQueueManager to limit concurrent tasks

See merge request knecon/research/pyinfra!97
2024-09-23 15:19:40 +02:00
Francisco Schulz
8ec13502a9 fix: add semaphore to AsyncQueueManager to limit concurrent tasks 2024-09-23 15:19:40 +02:00
Jonathan Kössler
43881de526 feat: add tests for types of documentreader 2024-09-20 16:42:55 +02:00
Julius Unverfehrt
67c30a5620 fix: recompile proto schemas with experimental schema update 2024-09-20 15:23:13 +02:00
40 changed files with 10856 additions and 7289 deletions

View File

@ -1,49 +1,23 @@
# CI for services, check gitlab repo for python package CI
include:
- project: "Gitlab/gitlab"
ref: 0.3.0
file: "/ci-templates/research/python_pkg_venv_test_build_release_gitlab-ci.yml"
default:
image: python:3.10
ref: main
file: "/ci-templates/research/python_pkg-test-build-release.gitlab-ci.yml"
# set project variables here
variables:
PIP_CACHE_DIR: "$CI_PROJECT_DIR/.cache/pip"
GITLAB_PYPI_URL: https://gitlab.knecon.com/api/v4/projects/${CI_PROJECT_ID}/packages/pypi
PYPI_REGISTRY_RESEARCH: https://gitlab.knecon.com/api/v4/groups/19/-/packages/pypi
POETRY_SOURCE_REF_RESEARCH: gitlab-research
PYPI_REGISTRY_RED: https://gitlab.knecon.com/api/v4/groups/12/-/packages/pypi
POETRY_SOURCE_REF_RED: gitlab-red
PYPI_REGISTRY_FFORESIGHT: https://gitlab.knecon.com/api/v4/groups/269/-/packages/pypi
POETRY_SOURCE_REF_FFORESIGHT: gitlab-fforesight
# POETRY_HOME: /opt/poetry
NEXUS_PROJECT_DIR: research # subfolder in Nexus docker-gin where your container will be stored
IMAGENAME: $CI_PROJECT_NAME # if the project URL is gitlab.example.com/group-name/project-1, CI_PROJECT_NAME is project-1
REPORTS_DIR: reports
FF_USE_FASTZIP: "true" # enable fastzip - a faster zip implementation that also supports level configuration.
ARTIFACT_COMPRESSION_LEVEL: default # can also be set to fastest, fast, slow and slowest. If just enabling fastzip is not enough try setting this to fastest or fast.
CACHE_COMPRESSION_LEVEL: default # same as above, but for caches
# TRANSFER_METER_FREQUENCY: 5s # will display transfer progress every 5 seconds for artifacts and remote caches. For debugging purposes.
setup-poetry-venv:
stage: setup
script:
- env # check env vars
# install poetry & return versions
- pip install --upgrade pip
- pip -V
- python -V
- pip install poetry
- poetry -V
# configure poetry
- poetry config installer.max-workers 10
- poetry config virtualenvs.in-project true
- poetry config repositories.${POETRY_SOURCE_REF_RESEARCH} ${PYPI_REGISTRY_RESEARCH}
- poetry config http-basic.${POETRY_SOURCE_REF_RESEARCH} ${CI_REGISTRY_USER} ${CI_JOB_TOKEN}
- poetry config repositories.${POETRY_SOURCE_REF_RED} ${PYPI_REGISTRY_RED}
- poetry config http-basic.${POETRY_SOURCE_REF_RED} ${CI_REGISTRY_USER} ${CI_JOB_TOKEN}
- poetry config repositories.${POETRY_SOURCE_REF_FFORESIGHT} ${PYPI_REGISTRY_FFORESIGHT}
- poetry config http-basic.${POETRY_SOURCE_REF_FFORESIGHT} ${CI_REGISTRY_USER} ${CI_JOB_TOKEN}
# create and activate venv
- poetry env use $(which python)
- source .venv/bin/activate
- python -m ensurepip
- env # check env vars again
# install from poetry.lock file
- poetry install --all-extras -vvv
run-tests:
script:
- echo "Disabled until we have an automated way to run docker compose before tests."
############
# UNIT TESTS
unit-tests:
variables:
###### UPDATE/EDIT ######
UNIT_TEST_DIR: "tests/unit_test"

View File

@ -5,7 +5,7 @@ default_language_version:
python: python3.10
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.6.0
rev: v5.0.0
hooks:
- id: trailing-whitespace
- id: end-of-file-fixer
@ -26,6 +26,7 @@ repos:
rev: v3.0.0a5
hooks:
- id: pylint
language: system
args:
- --disable=C0111,R0903
- --max-line-length=120
@ -38,7 +39,7 @@ repos:
- --profile black
- repo: https://github.com/psf/black
rev: 24.4.2
rev: 24.10.0
hooks:
- id: black
# exclude: ^(docs/|notebooks/|data/|src/secrets/)
@ -46,7 +47,7 @@ repos:
- --line-length=120
- repo: https://github.com/compilerla/conventional-pre-commit
rev: v3.2.0
rev: v3.6.0
hooks:
- id: conventional-pre-commit
pass_filenames: false

View File

@ -1 +1 @@
3.10.12
3.10

View File

@ -6,7 +6,7 @@
4. [ Module Installation ](#module-installation)
5. [ Scripts ](#scripts)
6. [ Tests ](#tests)
7. [ Protobuf ](#protobuf)
7. [ Opentelemetry protobuf dependency hell ](#opentelemetry-protobuf-dependency-hell)
## About
@ -73,6 +73,17 @@ the [complete example](pyinfra/examples.py).
| TRACING\_\_OPENTELEMETRY\_\_EXPORTER | tracing.opentelemetry.exporter | Name of exporter |
| KUBERNETES\_\_POD_NAME | kubernetes.pod_name | Service pod name |
## Setup
**IMPORTANT** you need to set the following environment variables before running the setup script:
- ``$NEXUS_USER`` your Nexus user (usually equal to firstname.lastname@knecon.com)
- ``$NEXUS_PASSWORD`` your Nexus password (usually equal to your Azure Login)
```shell
# create venv and activate it
source ./scripts/setup/devenvsetup.sh {{ cookiecutter.python_version }} $NEXUS_USER $NEXUS_PASSWORD
source .venv/bin/activate
```
### OpenTelemetry
OpenTelemetry (via its Python SDK) is set up to be as unobtrusive as possible; for typical use cases it can be
@ -202,48 +213,8 @@ $ python scripts/send_request.py
Tests require running MinIO and RabbitMQ containers, meaning you have to run `docker compose up` in the tests folder
before running the tests.
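A minimal sketch of that workflow, assuming the compose file lives directly in the tests folder and pytest runs from the repository root:
```shell
# start the MinIO and RabbitMQ containers the tests depend on
(cd tests && docker compose up -d)
# run the suite, then tear the containers down again
pytest tests
(cd tests && docker compose down)
```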
## Protobuf
## OpenTelemetry Protobuf Dependency Hell
### Opentelemetry Compatibility Issue
**Note**: Status 31/07/2024: the currently used `opentelemetry-exporter-otlp-proto-http` version `1.25.0` requires
a `protobuf` version < `5.x.x` and is not compatible with the latest protobuf version `5.27.x`. This is an [open issue](https://github.com/open-telemetry/opentelemetry-python/issues/3958) in opentelemetry, because [support for 4.25.x ends in Q2 '25](https://protobuf.dev/support/version-support/#python). Therefore, we should keep this in mind and update the dependency once opentelemetry includes support for `protobuf 5.27.x`.
### Install Protobuf Compiler
**Linux**
1. Download the version of the protobuf compiler matching the protobuf package, currently v4.25.4 so protoc v25.4, from [GitHub](https://github.com/protocolbuffers/protobuf/releases) -> `protobuf-25.4.zip`
2. Extract the files under `$HOME/.local` or another directory of your choice
```bash
unzip protoc-<version>-linux-x86_64.zip -d $HOME/.local
```
3. Ensure that the `bin` directory is in your `PATH` by adding the following line to your `.bashrc` or `.zshrc`:
```bash
export PATH="$PATH:$HOME/.local/bin"
```
**MacOS**
1. Download the version of the protobuf compiler matching the protobuf package, currently v4.25.4 so protoc v25.4, from [GitHub](https://github.com/protocolbuffers/protobuf/releases) -> `protoc-25.4-osx-universal_binary.zip`
2. Extract the files to a directory of your choice
3. Copy the executable bin `protoc` to `/usr/local/bin`
```bash
sudo cp /Users/you/location-of-unzipped-dir/bin/protoc /usr/local/bin/
```
4. Open `protoc` in `/usr/local/bin/` via Finder to make it executable; it should then also be on your `PATH`
### Compile Protobuf Files
1. Ensure that the protobuf compiler is installed on your system. You can check this by running:
```bash
protoc --version
```
2. Compile proto files:
```bash
protoc --proto_path=./config/proto --python_out=./pyinfra/proto ./config/proto/*.proto
```
3. Manually adjust import statements in the generated files to match the package structure, e.g.:
`import EntryData_pb2 as EntryData__pb2` -> `import pyinfra.proto.EntryData_pb2 as EntryData__pb2`.
This does not work automatically because the generated files are not in the same directory as the proto files.
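A hedged one-liner that automates this rewrite, assuming GNU sed and that the generated modules live under `pyinfra/proto/`:
```bash
# prefix the bare *_pb2 imports in the generated modules with the package path
sed -i 's/^import \([A-Za-z]\+_pb2\) as/import pyinfra.proto.\1 as/' pyinfra/proto/*_pb2.py
```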
**Note**: Status 2025/01/09: the currently used `opentelemetry-exporter-otlp-proto-http` version `1.25.0` requires
a `protobuf` version < `5.x.x` and is not compatible with the latest protobuf version `5.27.x`. This is an [open issue](https://github.com/open-telemetry/opentelemetry-python/issues/3958) in opentelemetry, because [support for 4.25.x ends in Q2 '25](https://protobuf.dev/support/version-support/#python).
Therefore, we should keep this in mind and update the dependency once opentelemetry includes support for `protobuf 5.27.x`.

11934
bom.json

File diff suppressed because it is too large

View File

@ -1,21 +0,0 @@
syntax = "proto3";
message AllDocumentPages {
repeated DocumentPage documentPages = 1;
}
message DocumentPage {
// The page number, starting with 1.
int32 number = 1;
// The page height in PDF user units.
int32 height = 2;
// The page width in PDF user units.
int32 width = 3;
// The page rotation as specified by the PDF.
int32 rotation = 4;
}

View File

@ -1,28 +0,0 @@
syntax = "proto3";
message AllDocumentPositionData {
repeated DocumentPositionData documentPositionData = 1;
}
message DocumentPositionData {
// Identifier of the text block.
int64 id = 1;
// For each string coordinate in the search text of the text block, the array contains an entry relating the string coordinate to the position coordinate.
// This is required due to the text and position coordinates not being equal.
repeated int32 stringIdxToPositionIdx = 2;
// The bounding box for each glyph as a rectangle. This matrix is of size (n,4), where n is the number of glyphs in the text block.
// The second dimension specifies the rectangle with the value x, y, width, height, with x, y specifying the lower left corner.
// In order to access this information, the stringIdxToPositionIdx array must be used to transform the coordinates.
repeated Position positions = 3;
// Definition of a BoundingBox that contains x, y, width, and height.
message Position {
float x = 1;
float y = 2;
float width = 3;
float height = 4;
}
}

View File

@ -1,8 +0,0 @@
syntax = "proto3";
import "EntryData.proto";
message DocumentStructure {
// The root EntryData represents the Document.
EntryData root = 1;
}

View File

@ -1,29 +0,0 @@
syntax = "proto3";
message AllDocumentTextData {
repeated DocumentTextData documentTextData = 1;
}
message DocumentTextData {
// Identifier of the text block.
int64 id = 1;
// The page the text block occurs on.
int64 page = 2;
// The text of the text block.
string searchText = 3;
// Each text block is assigned a number on a page, starting from 0.
int32 numberOnPage = 4;
// The text blocks are ordered, this number represents the start of the text block as a string offset.
int32 start = 5;
// The text blocks are ordered, this number represents the end of the text block as a string offset.
int32 end = 6;
// The line breaks in the text of this semantic node in string offsets. They are exclusive end. At the end of each semantic node there is an implicit linebreak.
repeated int32 lineBreaks = 7;
}

View File

@ -1,27 +0,0 @@
syntax = "proto3";
import "LayoutEngine.proto";
import "NodeType.proto";
message EntryData {
// Type of the semantic node.
NodeType type = 1;
// Specifies the position in the parsed tree structure.
repeated int32 treeId = 2;
// Specifies the text block IDs associated with this semantic node.
repeated int64 atomicBlockIds = 3;
// Specifies the pages this semantic node appears on.
repeated int64 pageNumbers = 4;
// Some semantic nodes have additional information, this information is stored in this Map.
map<string, string> properties = 5;
// All child Entries of this Entry.
repeated EntryData children = 6;
// Describes the origin of the semantic node.
repeated LayoutEngine engines = 7;
}

View File

@ -1,7 +0,0 @@
syntax = "proto3";
enum LayoutEngine {
ALGORITHM = 0;
AI = 1;
OUTLINE = 2;
}

View File

@ -1,14 +0,0 @@
syntax = "proto3";
enum NodeType {
DOCUMENT = 0;
SECTION = 1;
SUPER_SECTION = 2;
HEADLINE = 3;
PARAGRAPH = 4;
TABLE = 5;
TABLE_CELL = 6;
IMAGE = 7;
HEADER = 8;
FOOTER = 9;
}

4937
poetry.lock generated

File diff suppressed because it is too large

View File

@ -1,14 +1,17 @@
import asyncio
from tenacity import retry, stop_after_attempt, wait_exponential, retry_if_exception_type
import signal
import sys
import aiohttp
from aiormq.exceptions import AMQPConnectionError
from dynaconf import Dynaconf
from fastapi import FastAPI
from kn_utils.logging import logger
from pyinfra.config.loader import get_pyinfra_validators, validate_settings
from pyinfra.queue.async_manager import AsyncQueueManager, RabbitMQConfig
from pyinfra.queue.manager import QueueManager
from pyinfra.queue.callback import Callback
from pyinfra.queue.manager import QueueManager
from pyinfra.utils.opentelemetry import instrument_app, instrument_pika, setup_trace
from pyinfra.webserver.prometheus import (
add_prometheus_endpoint,
@ -20,38 +23,75 @@ from pyinfra.webserver.utils import (
run_async_webserver,
)
shutdown_flag = False
@retry(
stop=stop_after_attempt(5),
wait=wait_exponential(multiplier=1, min=4, max=60),
retry=retry_if_exception_type((Exception,)), # You might want to be more specific here
reraise=True,
)
async def run_async_queues(manager, app, port, host):
async def graceful_shutdown(manager: AsyncQueueManager, queue_task, webserver_task):
global shutdown_flag
shutdown_flag = True
logger.info("SIGTERM received, shutting down gracefully...")
if queue_task and not queue_task.done():
queue_task.cancel()
# await queue manager shutdown
await asyncio.gather(queue_task, manager.shutdown(), return_exceptions=True)
if webserver_task and not webserver_task.done():
webserver_task.cancel()
# await webserver shutdown
await asyncio.gather(webserver_task, return_exceptions=True)
logger.info("Shutdown complete.")
async def run_async_queues(manager: AsyncQueueManager, app, port, host):
"""Run the async webserver and the async queue manager concurrently."""
queue_task = None
webserver_task = None
tenant_api_available = True
# add signal handler for SIGTERM and SIGINT
loop = asyncio.get_running_loop()
loop.add_signal_handler(
signal.SIGTERM, lambda: asyncio.create_task(graceful_shutdown(manager, queue_task, webserver_task))
)
loop.add_signal_handler(
signal.SIGINT, lambda: asyncio.create_task(graceful_shutdown(manager, queue_task, webserver_task))
)
try:
await manager.run()
await run_async_webserver(app, port, host)
active_tenants = await manager.fetch_active_tenants()
queue_task = asyncio.create_task(manager.run(active_tenants=active_tenants), name="queues")
webserver_task = asyncio.create_task(run_async_webserver(app, port, host), name="webserver")
await asyncio.gather(queue_task, webserver_task)
except asyncio.CancelledError:
logger.info("Main task is cancelled.")
logger.info("Main task was cancelled, initiating shutdown.")
except AMQPConnectionError as e:
logger.warning(f"AMQPConnectionError: {e} - shutting down.")
except (aiohttp.ClientResponseError, aiohttp.ClientConnectorError):
logger.warning("Tenant server did not answer - shutting down.")
tenant_api_available = False
except Exception as e:
logger.error(f"An error occurred while running async queues: {e}", exc_info=True)
sys.exit(1)
finally:
logger.info("Signal received, shutting down...")
await manager.shutdown()
if shutdown_flag:
logger.debug("Graceful shutdown already in progress.")
else:
logger.warning("Initiating shutdown due to error or manual interruption.")
if not tenant_api_available:
sys.exit(0)
if queue_task and not queue_task.done():
queue_task.cancel()
if webserver_task and not webserver_task.done():
webserver_task.cancel()
# async def run_async_queues(manager, app, port, host):
# server = None
# try:
# await manager.run()
# server = await asyncio.start_server(app, host, port)
# await server.serve_forever()
# except Exception as e:
# logger.error(f"An error occurred while running async queues: {e}")
# finally:
# if server:
# server.close()
# await server.wait_closed()
# await manager.shutdown()
await asyncio.gather(queue_task, manager.shutdown(), webserver_task, return_exceptions=True)
logger.info("Shutdown complete.")
def start_standard_queue_consumer(
@ -80,10 +120,11 @@ def start_standard_queue_consumer(
if settings.tracing.enabled:
setup_trace(settings)
instrument_pika()
instrument_pika(dynamic_queues=settings.dynamic_tenant_queues.enabled)
instrument_app(app)
if settings.dynamic_tenant_queues.enabled:
logger.info("Dynamic tenant queues enabled. Running async queues.")
config = RabbitMQConfig(
host=settings.rabbitmq.host,
port=settings.rabbitmq.port,
@ -100,9 +141,15 @@ def start_standard_queue_consumer(
pod_name=settings.kubernetes.pod_name,
)
manager = AsyncQueueManager(
config=config, tenant_service_url=settings.storage.tenant_server.endpoint, message_processor=callback
config=config,
tenant_service_url=settings.storage.tenant_server.endpoint,
message_processor=callback,
max_concurrent_tasks=(
settings.asyncio.max_concurrent_tasks if hasattr(settings.asyncio, "max_concurrent_tasks") else 10
),
)
else:
logger.info("Dynamic tenant queues disabled. Running sync queues.")
manager = QueueManager(settings)
app = add_health_check_endpoint(app, manager.is_ready)
@ -116,9 +163,7 @@ def start_standard_queue_consumer(
try:
manager.start_consuming(callback)
except Exception as e:
logger.error(f"An error occurred while consuming messages: {e}")
# Optionally, you can choose to exit here if you want to restart the process
# import sys
# sys.exit(1)
logger.error(f"An error occurred while consuming messages: {e}", exc_info=True)
sys.exit(1)
else:
logger.warning(f"Behavior for type {type(manager)} is not defined")

View File

@ -1,29 +0,0 @@
# -*- coding: utf-8 -*-
# Generated by the protocol buffer compiler. DO NOT EDIT!
# source: DocumentPage.proto
# Protobuf Python Version: 4.25.4
"""Generated protocol buffer code."""
from google.protobuf import descriptor as _descriptor
from google.protobuf import descriptor_pool as _descriptor_pool
from google.protobuf import symbol_database as _symbol_database
from google.protobuf.internal import builder as _builder
# @@protoc_insertion_point(imports)
_sym_db = _symbol_database.Default()
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
b'\n\x12\x44ocumentPage.proto"8\n\x10\x41llDocumentPages\x12$\n\rdocumentPages\x18\x01 \x03(\x0b\x32\r.DocumentPage"O\n\x0c\x44ocumentPage\x12\x0e\n\x06number\x18\x01 \x01(\x05\x12\x0e\n\x06height\x18\x02 \x01(\x05\x12\r\n\x05width\x18\x03 \x01(\x05\x12\x10\n\x08rotation\x18\x04 \x01(\x05\x62\x06proto3'
)
_globals = globals()
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, "DocumentPage_pb2", _globals)
if _descriptor._USE_C_DESCRIPTORS == False:
DESCRIPTOR._options = None
_globals["_ALLDOCUMENTPAGES"]._serialized_start = 22
_globals["_ALLDOCUMENTPAGES"]._serialized_end = 78
_globals["_DOCUMENTPAGE"]._serialized_start = 80
_globals["_DOCUMENTPAGE"]._serialized_end = 159
# @@protoc_insertion_point(module_scope)

View File

@ -1,31 +0,0 @@
# -*- coding: utf-8 -*-
# Generated by the protocol buffer compiler. DO NOT EDIT!
# source: DocumentPositionData.proto
# Protobuf Python Version: 4.25.4
"""Generated protocol buffer code."""
from google.protobuf import descriptor as _descriptor
from google.protobuf import descriptor_pool as _descriptor_pool
from google.protobuf import symbol_database as _symbol_database
from google.protobuf.internal import builder as _builder
# @@protoc_insertion_point(imports)
_sym_db = _symbol_database.Default()
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
b'\n\x1a\x44ocumentPositionData.proto"N\n\x17\x41llDocumentPositionData\x12\x33\n\x14\x64ocumentPositionData\x18\x01 \x03(\x0b\x32\x15.DocumentPositionData"\xb6\x01\n\x14\x44ocumentPositionData\x12\n\n\x02id\x18\x01 \x01(\x03\x12\x1e\n\x16stringIdxToPositionIdx\x18\x02 \x03(\x05\x12\x31\n\tpositions\x18\x03 \x03(\x0b\x32\x1e.DocumentPositionData.Position\x1a?\n\x08Position\x12\t\n\x01x\x18\x01 \x01(\x02\x12\t\n\x01y\x18\x02 \x01(\x02\x12\r\n\x05width\x18\x03 \x01(\x02\x12\x0e\n\x06height\x18\x04 \x01(\x02\x62\x06proto3'
)
_globals = globals()
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, "DocumentPositionData_pb2", _globals)
if _descriptor._USE_C_DESCRIPTORS == False:
DESCRIPTOR._options = None
_globals["_ALLDOCUMENTPOSITIONDATA"]._serialized_start = 30
_globals["_ALLDOCUMENTPOSITIONDATA"]._serialized_end = 108
_globals["_DOCUMENTPOSITIONDATA"]._serialized_start = 111
_globals["_DOCUMENTPOSITIONDATA"]._serialized_end = 293
_globals["_DOCUMENTPOSITIONDATA_POSITION"]._serialized_start = 230
_globals["_DOCUMENTPOSITIONDATA_POSITION"]._serialized_end = 293
# @@protoc_insertion_point(module_scope)

View File

@ -1,29 +0,0 @@
# -*- coding: utf-8 -*-
# Generated by the protocol buffer compiler. DO NOT EDIT!
# source: DocumentStructure.proto
# Protobuf Python Version: 4.25.4
"""Generated protocol buffer code."""
from google.protobuf import descriptor as _descriptor
from google.protobuf import descriptor_pool as _descriptor_pool
from google.protobuf import symbol_database as _symbol_database
from google.protobuf.internal import builder as _builder
# @@protoc_insertion_point(imports)
_sym_db = _symbol_database.Default()
import pyinfra.proto.EntryData_pb2 as EntryData__pb2
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
b'\n\x17\x44ocumentStructure.proto\x1a\x0f\x45ntryData.proto"-\n\x11\x44ocumentStructure\x12\x18\n\x04root\x18\x01 \x01(\x0b\x32\n.EntryDatab\x06proto3'
)
_globals = globals()
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, "DocumentStructure_pb2", _globals)
if _descriptor._USE_C_DESCRIPTORS == False:
DESCRIPTOR._options = None
_globals["_DOCUMENTSTRUCTURE"]._serialized_start = 44
_globals["_DOCUMENTSTRUCTURE"]._serialized_end = 89
# @@protoc_insertion_point(module_scope)

View File

@ -1,29 +0,0 @@
# -*- coding: utf-8 -*-
# Generated by the protocol buffer compiler. DO NOT EDIT!
# source: DocumentTextData.proto
# Protobuf Python Version: 4.25.4
"""Generated protocol buffer code."""
from google.protobuf import descriptor as _descriptor
from google.protobuf import descriptor_pool as _descriptor_pool
from google.protobuf import symbol_database as _symbol_database
from google.protobuf.internal import builder as _builder
# @@protoc_insertion_point(imports)
_sym_db = _symbol_database.Default()
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
b'\n\x16\x44ocumentTextData.proto"B\n\x13\x41llDocumentTextData\x12+\n\x10\x64ocumentTextData\x18\x01 \x03(\x0b\x32\x11.DocumentTextData"\x86\x01\n\x10\x44ocumentTextData\x12\n\n\x02id\x18\x01 \x01(\x03\x12\x0c\n\x04page\x18\x02 \x01(\x03\x12\x12\n\nsearchText\x18\x03 \x01(\t\x12\x14\n\x0cnumberOnPage\x18\x04 \x01(\x05\x12\r\n\x05start\x18\x05 \x01(\x05\x12\x0b\n\x03\x65nd\x18\x06 \x01(\x05\x12\x12\n\nlineBreaks\x18\x07 \x03(\x05\x62\x06proto3'
)
_globals = globals()
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, "DocumentTextData_pb2", _globals)
if _descriptor._USE_C_DESCRIPTORS == False:
DESCRIPTOR._options = None
_globals["_ALLDOCUMENTTEXTDATA"]._serialized_start = 26
_globals["_ALLDOCUMENTTEXTDATA"]._serialized_end = 92
_globals["_DOCUMENTTEXTDATA"]._serialized_start = 95
_globals["_DOCUMENTTEXTDATA"]._serialized_end = 229
# @@protoc_insertion_point(module_scope)

View File

@ -1,34 +0,0 @@
# -*- coding: utf-8 -*-
# Generated by the protocol buffer compiler. DO NOT EDIT!
# source: EntryData.proto
# Protobuf Python Version: 4.25.4
"""Generated protocol buffer code."""
from google.protobuf import descriptor as _descriptor
from google.protobuf import descriptor_pool as _descriptor_pool
from google.protobuf import symbol_database as _symbol_database
from google.protobuf.internal import builder as _builder
# @@protoc_insertion_point(imports)
_sym_db = _symbol_database.Default()
import pyinfra.proto.LayoutEngine_pb2 as LayoutEngine__pb2
import pyinfra.proto.NodeType_pb2 as NodeType__pb2
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
b'\n\x0f\x45ntryData.proto\x1a\x12LayoutEngine.proto\x1a\x0eNodeType.proto"\x82\x02\n\tEntryData\x12\x17\n\x04type\x18\x01 \x01(\x0e\x32\t.NodeType\x12\x0e\n\x06treeId\x18\x02 \x03(\x05\x12\x16\n\x0e\x61tomicBlockIds\x18\x03 \x03(\x03\x12\x13\n\x0bpageNumbers\x18\x04 \x03(\x03\x12.\n\nproperties\x18\x05 \x03(\x0b\x32\x1a.EntryData.PropertiesEntry\x12\x1c\n\x08\x63hildren\x18\x06 \x03(\x0b\x32\n.EntryData\x12\x1e\n\x07\x65ngines\x18\x07 \x03(\x0e\x32\r.LayoutEngine\x1a\x31\n\x0fPropertiesEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\t:\x02\x38\x01\x62\x06proto3'
)
_globals = globals()
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, "EntryData_pb2", _globals)
if _descriptor._USE_C_DESCRIPTORS == False:
DESCRIPTOR._options = None
_globals["_ENTRYDATA_PROPERTIESENTRY"]._options = None
_globals["_ENTRYDATA_PROPERTIESENTRY"]._serialized_options = b"8\001"
_globals["_ENTRYDATA"]._serialized_start = 56
_globals["_ENTRYDATA"]._serialized_end = 314
_globals["_ENTRYDATA_PROPERTIESENTRY"]._serialized_start = 265
_globals["_ENTRYDATA_PROPERTIESENTRY"]._serialized_end = 314
# @@protoc_insertion_point(module_scope)

View File

@ -1,27 +0,0 @@
# -*- coding: utf-8 -*-
# Generated by the protocol buffer compiler. DO NOT EDIT!
# source: LayoutEngine.proto
# Protobuf Python Version: 4.25.4
"""Generated protocol buffer code."""
from google.protobuf import descriptor as _descriptor
from google.protobuf import descriptor_pool as _descriptor_pool
from google.protobuf import symbol_database as _symbol_database
from google.protobuf.internal import builder as _builder
# @@protoc_insertion_point(imports)
_sym_db = _symbol_database.Default()
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
b"\n\x12LayoutEngine.proto*2\n\x0cLayoutEngine\x12\r\n\tALGORITHM\x10\x00\x12\x06\n\x02\x41I\x10\x01\x12\x0b\n\x07OUTLINE\x10\x02\x62\x06proto3"
)
_globals = globals()
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, "LayoutEngine_pb2", _globals)
if _descriptor._USE_C_DESCRIPTORS == False:
DESCRIPTOR._options = None
_globals["_LAYOUTENGINE"]._serialized_start = 22
_globals["_LAYOUTENGINE"]._serialized_end = 72
# @@protoc_insertion_point(module_scope)

View File

@ -1,27 +0,0 @@
# -*- coding: utf-8 -*-
# Generated by the protocol buffer compiler. DO NOT EDIT!
# source: NodeType.proto
# Protobuf Python Version: 4.25.4
"""Generated protocol buffer code."""
from google.protobuf import descriptor as _descriptor
from google.protobuf import descriptor_pool as _descriptor_pool
from google.protobuf import symbol_database as _symbol_database
from google.protobuf.internal import builder as _builder
# @@protoc_insertion_point(imports)
_sym_db = _symbol_database.Default()
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
b"\n\x0eNodeType.proto*\x93\x01\n\x08NodeType\x12\x0c\n\x08\x44OCUMENT\x10\x00\x12\x0b\n\x07SECTION\x10\x01\x12\x11\n\rSUPER_SECTION\x10\x02\x12\x0c\n\x08HEADLINE\x10\x03\x12\r\n\tPARAGRAPH\x10\x04\x12\t\n\x05TABLE\x10\x05\x12\x0e\n\nTABLE_CELL\x10\x06\x12\t\n\x05IMAGE\x10\x07\x12\n\n\x06HEADER\x10\x08\x12\n\n\x06\x46OOTER\x10\tb\x06proto3"
)
_globals = globals()
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, "NodeType_pb2", _globals)
if _descriptor._USE_C_DESCRIPTORS == False:
DESCRIPTOR._options = None
_globals["_NODETYPE"]._serialized_start = 19
_globals["_NODETYPE"]._serialized_end = 166
# @@protoc_insertion_point(module_scope)

View File

@ -1,13 +1,11 @@
import asyncio
import concurrent.futures
import json
import signal
import sys
from dataclasses import dataclass, field
from typing import Any, Callable, Dict, Set
import aiohttp
from aio_pika import ExchangeType, IncomingMessage, Message, connect_robust
from aio_pika import ExchangeType, IncomingMessage, Message, connect
from aio_pika.abc import (
AbstractChannel,
AbstractConnection,
@ -15,17 +13,14 @@ from aio_pika.abc import (
AbstractIncomingMessage,
AbstractQueue,
)
from kn_utils.logging import logger
from tenacity import (
retry,
retry_if_exception_type,
stop_after_attempt,
wait_exponential_jitter,
wait_exponential,
retry_if_exception_type,
from aio_pika.exceptions import (
ChannelClosed,
ChannelInvalidStateError,
ConnectionClosed,
)
from aio_pika.exceptions import AMQPConnectionError, ChannelInvalidStateError
from aiormq.exceptions import AMQPConnectionError
from kn_utils.logging import logger
from kn_utils.retry import retry
@dataclass
@ -62,10 +57,12 @@ class AsyncQueueManager:
config: RabbitMQConfig,
tenant_service_url: str,
message_processor: Callable[[Dict[str, Any]], Dict[str, Any]],
max_concurrent_tasks: int = 10,
):
self.config = config
self.tenant_service_url = tenant_service_url
self.message_processor = message_processor
self.semaphore = asyncio.Semaphore(max_concurrent_tasks)
self.connection: AbstractConnection | None = None
self.channel: AbstractChannel | None = None
@ -78,19 +75,32 @@ class AsyncQueueManager:
self.message_count: int = 0
@retry(
stop=stop_after_attempt(5),
wait=wait_exponential(multiplier=1, min=4, max=60),
retry=retry_if_exception_type(AMQPConnectionError),
reraise=True,
)
@retry(tries=5, exceptions=AMQPConnectionError, reraise=True, logger=logger)
async def connect(self) -> None:
logger.info("Attempting to connect to RabbitMQ...")
self.connection = await connect_robust(**self.config.connection_params)
self.connection = await connect(**self.config.connection_params)
self.connection.close_callbacks.add(self.on_connection_close)
self.channel = await self.connection.channel()
await self.channel.set_qos(prefetch_count=1)
logger.info("Successfully connected to RabbitMQ")
async def on_connection_close(self, sender, exc):
"""This is a callback for unexpected connection closures."""
logger.debug(f"Sender: {sender}")
if isinstance(exc, ConnectionClosed):
logger.warning("Connection to RabbitMQ lost. Attempting to reconnect...")
try:
active_tenants = await self.fetch_active_tenants()
await self.run(active_tenants=active_tenants)
logger.debug("Reconnected to RabbitMQ successfully")
except Exception as e:
logger.warning(f"Failed to reconnect to RabbitMQ: {e}")
# cancel queue manager and webserver to shutdown service
tasks = [t for t in asyncio.all_tasks() if t is not asyncio.current_task()]
[task.cancel() for task in tasks if task.get_name() in ["queues", "webserver"]]
else:
logger.debug("Connection closed on purpose.")
async def is_ready(self) -> bool:
if self.connection is None or self.connection.is_closed:
try:
@ -100,12 +110,7 @@ class AsyncQueueManager:
return False
return True
@retry(
stop=stop_after_attempt(5),
wait=wait_exponential(multiplier=1, min=4, max=60),
retry=retry_if_exception_type((AMQPConnectionError, ChannelInvalidStateError)),
reraise=True,
)
@retry(tries=5, exceptions=(AMQPConnectionError, ChannelInvalidStateError), reraise=True, logger=logger)
async def setup_exchanges(self) -> None:
self.tenant_exchange = await self.channel.declare_exchange(
self.config.tenant_exchange_name, ExchangeType.TOPIC, durable=True
@ -117,12 +122,12 @@ class AsyncQueueManager:
self.config.service_response_exchange_name, ExchangeType.DIRECT, durable=True
)
@retry(
stop=stop_after_attempt(5),
wait=wait_exponential(multiplier=1, min=4, max=60),
retry=retry_if_exception_type((AMQPConnectionError, ChannelInvalidStateError)),
reraise=True,
)
# we must declare DLQ to handle error messages
self.dead_letter_queue = await self.channel.declare_queue(
self.config.service_dead_letter_queue_name, durable=True
)
@retry(tries=5, exceptions=(AMQPConnectionError, ChannelInvalidStateError), reraise=True, logger=logger)
async def setup_tenant_queue(self) -> None:
self.tenant_exchange_queue = await self.channel.declare_queue(
f"{self.config.pod_name}_{self.config.tenant_event_queue_suffix}",
@ -160,6 +165,10 @@ class AsyncQueueManager:
input_queue = await self.channel.declare_queue(
queue_name,
durable=True,
arguments={
"x-dead-letter-exchange": "",
"x-dead-letter-routing-key": self.config.service_dead_letter_queue_name,
},
)
await input_queue.bind(self.input_exchange, routing_key=tenant_id)
self.consumer_tags[tenant_id] = await input_queue.consume(self.process_input_message)
@ -178,11 +187,14 @@ class AsyncQueueManager:
async def process_input_message(self, message: IncomingMessage) -> None:
async def process_message_body_and_await_result(unpacked_message_body):
loop = asyncio.get_running_loop()
with concurrent.futures.ThreadPoolExecutor(max_workers=1) as thread_pool_executor:
logger.info("Processing payload in a separate thread.")
result = await loop.run_in_executor(thread_pool_executor, self.message_processor, unpacked_message_body)
return result
async with self.semaphore:
loop = asyncio.get_running_loop()
with concurrent.futures.ThreadPoolExecutor(max_workers=1) as thread_pool_executor:
logger.info("Processing payload in a separate thread.")
result = await loop.run_in_executor(
thread_pool_executor, self.message_processor, unpacked_message_body
)
return result
async with message.process(ignore_processed=True):
if message.redelivered:
@ -222,14 +234,13 @@ class AsyncQueueManager:
except json.JSONDecodeError:
await message.nack(requeue=False)
logger.error(f"Invalid JSON in input message: {message.body}")
logger.error(f"Invalid JSON in input message: {message.body}", exc_info=True)
except FileNotFoundError as e:
logger.warning(f"{e}, declining message with {message.delivery_tag=}.")
logger.warning(f"{e}, declining message with {message.delivery_tag=}.", exc_info=True)
await message.nack(requeue=False)
except Exception as e:
await message.nack(requeue=False)
logger.error(f"Error processing input message: {e}", exc_info=True)
raise
finally:
self.message_count -= 1
@ -240,12 +251,7 @@ class AsyncQueueManager:
)
logger.info(f"Published result to queue {tenant_id}.")
@retry(
stop=stop_after_attempt(5),
wait=wait_exponential_jitter(initial=1, max=10),
retry=retry_if_exception_type((aiohttp.ClientResponseError, aiohttp.ClientConnectorError)),
reraise=True,
)
@retry(tries=5, exceptions=(aiohttp.ClientResponseError, aiohttp.ClientConnectorError), reraise=True, logger=logger)
async def fetch_active_tenants(self) -> Set[str]:
async with aiohttp.ClientSession() as session:
async with session.get(self.tenant_service_url) as response:
@ -260,49 +266,64 @@ class AsyncQueueManager:
return set()
@retry(
stop=stop_after_attempt(5),
wait=wait_exponential(multiplier=1, min=4, max=60),
retry=retry_if_exception_type((AMQPConnectionError, ChannelInvalidStateError)),
tries=5,
exceptions=(
AMQPConnectionError,
ChannelInvalidStateError,
),
reraise=True,
logger=logger,
)
async def initialize_tenant_queues(self) -> None:
try:
active_tenants = await self.fetch_active_tenants()
except (aiohttp.ClientResponseError, aiohttp.ClientConnectorError):
logger.warning("API calls to tenant server failed. No tenant queues initialized.")
active_tenants = set()
async def initialize_tenant_queues(self, active_tenants: set) -> None:
for tenant_id in active_tenants:
await self.create_tenant_queues(tenant_id)
async def run(self) -> None:
try:
await self.connect()
await self.setup_exchanges()
await self.initialize_tenant_queues()
await self.setup_tenant_queue()
async def run(self, active_tenants: set) -> None:
logger.info("RabbitMQ handler is running. Press CTRL+C to exit.")
except AMQPConnectionError as e:
logger.error(f"Failed to establish connection to RabbitMQ: {e}")
# TODO: implement a custom exception handling strategy here
except asyncio.CancelledError:
logger.warning("Operation cancelled.")
await self.connect()
await self.setup_exchanges()
await self.initialize_tenant_queues(active_tenants=active_tenants)
await self.setup_tenant_queue()
logger.info("RabbitMQ handler is running. Press CTRL+C to exit.")
async def close_channels(self) -> None:
try:
if self.channel and not self.channel.is_closed:
# Cancel queues to stop fetching messages
logger.debug("Cancelling queues...")
for tenant, queue in self.tenant_queues.items():
await queue.cancel(self.consumer_tags[tenant])
if self.tenant_exchange_queue:
await self.tenant_exchange_queue.cancel(self.consumer_tags["tenant_exchange_queue"])
while self.message_count != 0:
logger.debug(f"Messages are still being processed: {self.message_count=} ")
await asyncio.sleep(2)
await self.channel.close(exc=asyncio.CancelledError)
logger.debug("Channel closed.")
else:
logger.debug("No channel to close.")
except ChannelClosed:
logger.warning("Channel was already closed.")
except ConnectionClosed:
logger.warning("Connection was lost, unable to close channel.")
except Exception as e:
logger.error(f"An error occurred: {e}", exc_info=True)
logger.error(f"Error during channel shutdown: {e}")
async def close_connection(self) -> None:
try:
if self.connection and not self.connection.is_closed:
await self.connection.close(exc=asyncio.CancelledError)
logger.debug("Connection closed.")
else:
logger.debug("No connection to close.")
except ConnectionClosed:
logger.warning("Connection was already closed.")
except Exception as e:
logger.error(f"Error closing connection: {e}")
async def shutdown(self) -> None:
logger.info("Shutting down RabbitMQ handler...")
if self.channel:
# Cancel queues to stop fetching messages
logger.debug("Cancelling queues...")
for tenant, queue in self.tenant_queues.items():
await queue.cancel(self.consumer_tags[tenant])
await self.tenant_exchange_queue.cancel(self.consumer_tags["tenant_exchange_queue"])
while self.message_count != 0:
logger.debug(f"Messages are still being processed: {self.message_count=} ")
await asyncio.sleep(2)
await self.channel.close()
if self.connection:
await self.connection.close()
await self.close_channels()
await self.close_connection()
logger.info("RabbitMQ handler shut down successfully.")
sys.exit(0)
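The tenacity decorators removed in this file are replaced throughout the MR by `kn_utils.retry`; a minimal sketch of the migrated call shape, assuming the keyword arguments visible in the diff (`tries`, `exceptions`, `reraise`, `logger`) are the helper's interface:
```python
from aiormq.exceptions import AMQPConnectionError
from kn_utils.logging import logger
from kn_utils.retry import retry


# hypothetical standalone coroutine mirroring the decorator usage shown above
@retry(tries=5, exceptions=AMQPConnectionError, reraise=True, logger=logger)
async def connect_to_broker(config: dict) -> None:
    # placeholder for the actual aio-pika connect call; a failing attempt is
    # retried up to `tries` times and the last exception is re-raised afterwards
    ...
```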

View File

@ -1,15 +1,16 @@
from typing import Callable, Union
from typing import Callable
from dynaconf import Dynaconf
from kn_utils.logging import logger
from pyinfra.storage.connection import get_storage
from pyinfra.storage.utils import (
download_data_as_specified_in_message,
download_data_bytes_as_specified_in_message,
upload_data_as_specified_in_message,
DownloadedData,
)
DataProcessor = Callable[[Union[dict, bytes], dict], Union[dict, list, str]]
DataProcessor = Callable[[dict[str, DownloadedData] | DownloadedData, dict], dict | list | str]
Callback = Callable[[dict], dict]
@ -28,7 +29,9 @@ def make_download_process_upload_callback(data_processor: DataProcessor, setting
storage = get_storage(settings, queue_message_payload.get("X-TENANT-ID"))
data = download_data_as_specified_in_message(storage, queue_message_payload)
data: dict[str, DownloadedData] | DownloadedData = download_data_bytes_as_specified_in_message(
storage, queue_message_payload
)
result = data_processor(data, queue_message_payload)

View File

@ -10,9 +10,8 @@ import pika
import pika.exceptions
from dynaconf import Dynaconf
from kn_utils.logging import logger
from kn_utils.retry import retry
from pika.adapters.blocking_connection import BlockingChannel, BlockingConnection
from retry import retry
from tenacity import retry, stop_after_attempt, wait_exponential, retry_if_exception_type
from pyinfra.config.loader import validate_settings
from pyinfra.config.validators import queue_manager_validators
@ -59,9 +58,8 @@ class QueueManager:
return pika.ConnectionParameters(**pika_connection_params)
@retry(
stop=stop_after_attempt(5),
wait=wait_exponential(multiplier=1, min=4, max=60),
retry=retry_if_exception_type((pika.exceptions.AMQPConnectionError, pika.exceptions.ChannelClosedByBroker)),
tries=5,
exceptions=(pika.exceptions.AMQPConnectionError, pika.exceptions.ChannelClosedByBroker),
reraise=True,
)
def establish_connection(self):
@ -95,9 +93,8 @@ class QueueManager:
return False
@retry(
stop=stop_after_attempt(5),
wait=wait_exponential(multiplier=1, min=4, max=60),
retry=retry_if_exception_type(pika.exceptions.AMQPConnectionError),
tries=5,
exceptions=pika.exceptions.AMQPConnectionError,
reraise=True,
)
def start_consuming(self, message_processor: Callable):

View File

@ -1,95 +0,0 @@
import re
from enum import Enum
from pathlib import Path
from google.protobuf.json_format import MessageToDict
from kn_utils.logging import logger
from pyinfra.proto import (
DocumentPage_pb2,
DocumentPositionData_pb2,
DocumentStructure_pb2,
DocumentTextData_pb2,
)
class ProtoDataLoader:
"""Loads proto data from a file and returns it as a dictionary or list.
The loader is a singleton and should be used as a callable. The file name and byte data are passed as arguments.
The document type is determined based on the file name and the data is returned as a dictionary or list, depending
on the document type.
The DocumentType enum contains all supported document types and their corresponding proto schema.
KEYS_TO_UNPACK contains the keys that should be unpacked from the message dictionary. Keys are unpacked if the
message dictionary contains only one key. This behaviour is necessary since lists are wrapped in a dictionary.
"""
_instance = None
_pattern = None
class DocumentType(Enum):
STRUCTURE = (DocumentStructure_pb2.DocumentStructure, "DocumentStructure")
TEXT = (DocumentTextData_pb2.AllDocumentTextData, "AllDocumentTextData")
PAGES = (DocumentPage_pb2.AllDocumentPages, "AllDocumentPages")
POSITION = (DocumentPositionData_pb2.AllDocumentPositionData, "AllDocumentPositionData")
KEYS_TO_UNPACK = ["documentTextData", "documentPages", "documentPositionData"]
@classmethod
def _build_pattern(cls) -> re.Pattern:
types = "|".join([dt.name for dt in cls.DocumentType])
return re.compile(rf"\..*({types}).*\.proto.*")
def __new__(cls, *args, **kwargs):
if cls._instance is None:
cls._instance = super().__new__(cls)
cls._pattern = cls._build_pattern()
return cls._instance
def __call__(self, file_name: str | Path, data: bytes) -> dict:
return self._load(file_name, data)
def _load(self, file_name: str | Path, data: bytes) -> dict | list:
file_name = str(file_name)
document_type = self._match(file_name)
if not document_type:
logger.error(f"Unknown document type: {file_name}, supported types: {self.DocumentType}")
return {}
logger.debug(f"Loading document type: {document_type}")
schema, _ = self.DocumentType[document_type].value
message = schema()
message.ParseFromString(data)
message_dict = MessageToDict(message, including_default_value_fields=True)
message_dict = convert_int64_fields(message_dict)
return self._unpack(message_dict)
def _match(self, file_name: str) -> str | None:
match = self._pattern.search(file_name)
return match.group(1) if match else None
def _unpack(self, message_dict: dict) -> list | dict:
if len(message_dict) > 1:
return message_dict
for key in self.KEYS_TO_UNPACK:
if key in message_dict:
logger.debug(f"Unpacking key: {key}")
return message_dict[key]
return message_dict
def convert_int64_fields(obj):
# FIXME: find a more sophisticated way to convert int64 fields (defaults to str in python)
if isinstance(obj, dict):
for key, value in obj.items():
obj[key] = convert_int64_fields(value)
elif isinstance(obj, list):
return [convert_int64_fields(item) for item in obj]
elif isinstance(obj, str) and obj.isdigit():
return int(obj)
return obj

View File

@ -1,12 +1,11 @@
import gzip
import json
from functools import singledispatch
from typing import Union
from typing import TypedDict
from kn_utils.logging import logger
from pydantic import BaseModel, ValidationError
from pyinfra.storage.proto_data_loader import ProtoDataLoader
from pyinfra.storage.storages.storage import Storage
@ -53,28 +52,27 @@ class TenantIdDossierIdFileIdUploadPayload(BaseModel):
class TargetResponseFilePathDownloadPayload(BaseModel):
targetFilePath: Union[str, dict]
targetFilePath: str | dict[str, str]
class TargetResponseFilePathUploadPayload(BaseModel):
responseFilePath: str
def download_data_as_specified_in_message(storage: Storage, raw_payload: dict) -> Union[dict, bytes]:
class DownloadedData(TypedDict):
data: bytes
file_path: str
def download_data_bytes_as_specified_in_message(
storage: Storage, raw_payload: dict
) -> dict[str, DownloadedData] | DownloadedData:
"""Convenience function to download a file specified in a message payload.
Supports both legacy and new payload formats. Also supports downloading multiple files at once, which should
be specified in a dictionary under the 'targetFilePath' key with the file path as value.
If the content is compressed with gzip (.gz), it will be decompressed (-> bytes).
If the content is a json file, it will be decoded (-> dict).
If no file is specified in the payload or the file does not exist in storage, an exception will be raised.
In all other cases, the content will be returned as is (-> bytes).
This function can be extended in the future as needed (e.g. handling of more file types), but since further
requirements are not specified at this point in time, and it is unclear what these would entail, the code is kept
simple for now to improve readability, maintainability and avoid refactoring efforts of generic solutions that
weren't as generic as they seemed.
The data is downloaded as bytes and returned as a dictionary with the file path as key and the data as value.
In case of several download targets, a nested dictionary is returned with the same keys and dictionaries with
the file path and data as values.
"""
try:
@ -93,33 +91,25 @@ def download_data_as_specified_in_message(storage: Storage, raw_payload: dict) -
@singledispatch
def _download(file_path_or_file_path_dict: Union[str, dict], storage: Storage) -> Union[dict, bytes]:
def _download(
file_path_or_file_path_dict: str | dict[str, str], storage: Storage
) -> dict[str, DownloadedData] | DownloadedData:
pass
@_download.register(str)
def _download_single_file(file_path: str, storage: Storage) -> bytes:
def _download_single_file(file_path: str, storage: Storage) -> DownloadedData:
if not storage.exists(file_path):
raise FileNotFoundError(f"File '{file_path}' does not exist in storage.")
data = storage.get_object(file_path)
data = gzip.decompress(data) if ".gz" in file_path else data
if ".json" in file_path:
data = json.loads(data.decode("utf-8"))
elif ".proto" in file_path:
data = ProtoDataLoader()(file_path, data)
else:
pass # identity for other file types
logger.info(f"Downloaded {file_path} from storage.")
return data
return DownloadedData(data=data, file_path=file_path)
@_download.register(dict)
def _download_multiple_files(file_path_dict: dict, storage: Storage) -> dict:
def _download_multiple_files(file_path_dict: dict, storage: Storage) -> dict[str, DownloadedData]:
return {key: _download(value, storage) for key, value in file_path_dict.items()}
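Given the `DownloadedData` shape introduced above, a hedged sketch of what a downstream `data_processor` now receives (single target vs. the nested multi-file case); the key names and return values are illustrative only:
```python
from pyinfra.storage.utils import DownloadedData


# hypothetical processor consuming the bytes-only download results
def data_processor(data: dict[str, DownloadedData] | DownloadedData, payload: dict) -> dict:
    if "data" in data and "file_path" in data:  # single download target
        return {"source": data["file_path"], "size": len(data["data"])}
    # several targets: nested dict keyed by the names used under 'targetFilePath'
    return {key: len(item["data"]) for key, item in data.items()}
```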

View File

@ -3,8 +3,10 @@ import json
from azure.monitor.opentelemetry import configure_azure_monitor
from dynaconf import Dynaconf
from fastapi import FastAPI
from kn_utils.logging import logger
from opentelemetry import trace
from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import OTLPSpanExporter
from opentelemetry.instrumentation.aio_pika import AioPikaInstrumentor
from opentelemetry.instrumentation.fastapi import FastAPIInstrumentor
from opentelemetry.instrumentation.pika import PikaInstrumentor
from opentelemetry.sdk.resources import Resource
@ -18,7 +20,6 @@ from opentelemetry.sdk.trace.export import (
from pyinfra.config.loader import validate_settings
from pyinfra.config.validators import opentelemetry_validators
from kn_utils.logging import logger
class JsonSpanExporter(SpanExporter):
@ -37,7 +38,7 @@ class JsonSpanExporter(SpanExporter):
def setup_trace(settings: Dynaconf, service_name: str = None, exporter: SpanExporter = None):
tracing_type = settings.tracing.type
if tracing_type == "azure_monitor":
# Configure OpenTelemetry to use Azure Monitor with the
# Configure OpenTelemetry to use Azure Monitor with the
# APPLICATIONINSIGHTS_CONNECTION_STRING environment variable.
try:
configure_azure_monitor()
@ -84,8 +85,11 @@ def get_exporter(settings: Dynaconf):
)
def instrument_pika():
PikaInstrumentor().instrument()
def instrument_pika(dynamic_queues: bool):
if dynamic_queues:
AioPikaInstrumentor().instrument()
else:
PikaInstrumentor().instrument()
def instrument_app(app: FastAPI, excluded_urls: str = "/health,/ready,/prometheus"):

View File

@ -4,22 +4,28 @@ import logging
import signal
import threading
import time
from tenacity import retry, stop_after_attempt, wait_exponential, retry_if_exception_type
from typing import Callable
import uvicorn
from dynaconf import Dynaconf
from fastapi import FastAPI
from kn_utils.logging import logger
from kn_utils.retry import retry
from pyinfra.config.loader import validate_settings
from pyinfra.config.validators import webserver_validators
class PyInfraUvicornServer(uvicorn.Server):
# this is a workaround to enable custom signal handlers
# https://github.com/encode/uvicorn/issues/1579
def install_signal_handlers(self):
pass
@retry(
stop=stop_after_attempt(5),
wait=wait_exponential(multiplier=1, min=2, max=30),
retry=retry_if_exception_type((Exception,)), # You might want to be more specific here
tries=5,
exceptions=Exception,
reraise=True,
)
def create_webserver_thread_from_settings(app: FastAPI, settings: Dynaconf) -> threading.Thread:
@ -54,22 +60,18 @@ def create_webserver_thread(app: FastAPI, port: int, host: str) -> threading.Thr
async def run_async_webserver(app: FastAPI, port: int, host: str):
"""Run the FastAPI web server async."""
config = uvicorn.Config(app, host=host, port=port, log_level=logging.WARNING)
server = uvicorn.Server(config)
async def shutdown(signal):
logger.info(f"Received signal {signal.name}, shutting down webserver...")
await app.shutdown()
await app.cleanup()
logger.info("Shutdown complete.")
loop = asyncio.get_event_loop()
for sig in (signal.SIGTERM, signal.SIGINT):
loop.add_signal_handler(sig, lambda s=sig: asyncio.create_task(shutdown(s)))
server = PyInfraUvicornServer(config)
try:
await server.serve()
except asyncio.CancelledError:
pass
logger.debug("Webserver was cancelled.")
server.should_exit = True
await server.shutdown()
except Exception as e:
logger.error(f"Error while running the webserver: {e}", exc_info=True)
finally:
logger.info("Webserver has been shut down.")
HealthFunction = Callable[[], bool]

View File

@@ -1,6 +1,6 @@
[tool.poetry]
name = "pyinfra"
version = "3.2.7"
version = "4.1.0"
description = ""
authors = ["Team Research <research@knecon.com>"]
license = "All rights reserved"
@@ -18,31 +18,34 @@ azure-storage-blob = "^12.13"
# misc utils
funcy = "^2"
pycryptodome = "^3.19"
# research shared packages
kn-utils = { version = "^0.2.7", source = "gitlab-research" }
fastapi = "^0.109.0"
uvicorn = "^0.26.0"
# [tool.poetry.group.telemetry.dependencies]
opentelemetry-instrumentation-pika = "^0.46b0"
opentelemetry-exporter-otlp = "^1.25.0"
opentelemetry-instrumentation = "^0.46b0"
opentelemetry-api = "^1.25.0"
opentelemetry-sdk = "^1.25.0"
opentelemetry-exporter-otlp-proto-http = "^1.25.0"
opentelemetry-instrumentation-flask = "^0.46b0"
opentelemetry-instrumentation-requests = "^0.46b0"
opentelemetry-instrumentation-fastapi = "^0.46b0"
# DONT USE GROUPS BECAUSE THEY ARE NOT INSTALLED FOR PACKAGES
# [tool.poetry.group.internal.dependencies] <<< THIS IS NOT WORKING
kn-utils = { version = ">=0.4.0", source = "nexus" }
# We pin all opentelemetry dependencies to a lower, bounded version range because the image classification service depends on a protobuf version <4 but does not use proto files.
# This way, the services which do use proto files can still pull the latest possible protobuf version. As soon as the dependency issue is fixed, set these to the latest possible opentelemetry version.
opentelemetry-instrumentation-pika = ">=0.46b0,<0.50"
opentelemetry-exporter-otlp = ">=1.25.0,<1.29"
opentelemetry-instrumentation = ">=0.46b0,<0.50"
opentelemetry-api = ">=1.25.0,<1.29"
opentelemetry-sdk = ">=1.25.0,<1.29"
opentelemetry-exporter-otlp-proto-http = ">=1.25.0,<1.29"
opentelemetry-instrumentation-flask = ">=0.46b0,<0.50"
opentelemetry-instrumentation-requests = ">=0.46b0,<0.50"
opentelemetry-instrumentation-fastapi = ">=0.46b0,<0.50"
opentelemetry-instrumentation-aio-pika = ">=0.46b0,<0.50"
wcwidth = "<=0.2.12"
azure-monitor-opentelemetry = "^1.6.0"
protobuf = "^3.20"
aio-pika = "^9.4.2"
aiohttp = "^3.9.5"
tenacity = "^8.5.0"
# THIS IS NOT AVAILABLE FOR SERVICES THAT IMPLEMENT PYINFRA
[tool.poetry.group.dev.dependencies]
pytest = "^7"
ipykernel = "^6.26.0"
black = "^23.10"
black = "^24.10"
pylint = "^3"
coverage = "^7.3"
requests = "^2.31"
@@ -51,6 +54,7 @@ cyclonedx-bom = "^4.1.1"
dvc = "^3.51.2"
dvc-azure = "^3.1.0"
deepdiff = "^7.0.1"
pytest-cov = "^5.0.0"
[tool.pytest.ini_options]
minversion = "6.0"
@@ -85,12 +89,13 @@ disable = [
docstring-min-length = 3
[[tool.poetry.source]]
name = "PyPI"
name = "pypi-proxy"
url = "https://nexus.knecon.com/repository/pypi-proxy/simple"
priority = "primary"
[[tool.poetry.source]]
name = "gitlab-research"
url = "https://gitlab.knecon.com/api/v4/groups/19/-/packages/pypi/simple"
name = "nexus"
url = "https://nexus.knecon.com/repository/python/simple"
priority = "explicit"
[build-system]

scripts/send_sigterm.py (new file)
View File

@@ -0,0 +1,17 @@
import os
import signal
import time
# BE CAREFUL WITH THIS SCRIPT - THIS SIMULATES A SIGTERM FROM KUBERNETES
target_pid = int(input("Enter the PID of the target script: "))
print(f"Sending SIGTERM to PID {target_pid}...")
time.sleep(1)
try:
os.kill(target_pid, signal.SIGTERM)
print("SIGTERM sent.")
except ProcessLookupError:
print("Process not found.")
except PermissionError:
print("Permission denied. Are you trying to signal a process you don't own?")
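To try the script locally, the target process needs to handle SIGTERM the way the services do under Kubernetes. A purely illustrative target that prints its PID and exits cleanly when signalled:

import os
import signal
import time

stop = False

def handle_sigterm(signum, frame):
    global stop
    print(f"Received {signal.Signals(signum).name}, shutting down...")
    stop = True

signal.signal(signal.SIGTERM, handle_sigterm)
print(f"Running as PID {os.getpid()}; pass this PID to scripts/send_sigterm.py")

while not stop:
    time.sleep(0.5)  # stand-in for real work

print("Clean shutdown complete.")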

View File

@@ -0,0 +1,39 @@
#!/bin/bash
python_version=$1
nexus_user=$2
nexus_password=$3
# cookiecutter https://gitlab.knecon.com/knecon/research/template-python-project.git --checkout master
# latest_dir=$(ls -td -- */ | head -n 1) # should be the dir cookiecutter just created
# cd $latest_dir
pyenv install $python_version
pyenv local $python_version
pyenv shell $python_version
# install poetry globally (PREFERRED), only need to install it once
# curl -sSL https://install.python-poetry.org | python3 -
# remember to update poetry once in a while
poetry self update
# install poetry in current python environment, can lead to multiple instances of poetry being installed on one system (DISPREFERRED)
# pip install --upgrade pip
# pip install poetry
poetry config virtualenvs.in-project true
poetry config installer.max-workers 10
poetry config repositories.pypi-proxy "https://nexus.knecon.com/repository/pypi-proxy/simple"
poetry config http-basic.pypi-proxy ${nexus_user} ${nexus_password}
poetry config repositories.nexus https://nexus.knecon.com/repository/python/simple
poetry config http-basic.nexus ${nexus_user} ${nexus_password}
poetry env use $(pyenv which python)
poetry install --with=dev
poetry update
source .venv/bin/activate
pre-commit install
pre-commit autoupdate

View File

@@ -27,6 +27,8 @@ def storage(storage_backend, settings):
def queue_manager(settings):
settings.rabbitmq_heartbeat = 10
settings.connection_sleep = 5
settings.rabbitmq.max_retries = 3
settings.rabbitmq.max_delay = 10
queue_manager = QueueManager(settings)
yield queue_manager
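The fixture now also bounds reconnect behaviour via rabbitmq.max_retries and rabbitmq.max_delay. How QueueManager consumes these settings is not part of this diff; the intent is a capped backoff roughly along these lines (purely illustrative sketch, only the setting names are borrowed from the fixture):

import time

def connect_with_backoff(connect, max_retries=3, max_delay=10):
    """Call connect() until it succeeds, doubling the wait (1s, 2s, 4s, ...) but never beyond max_delay."""
    delay = 1
    for attempt in range(1, max_retries + 1):
        try:
            return connect()
        except Exception as exc:  # in practice something like pika.exceptions.AMQPConnectionError
            if attempt == max_retries:
                raise
            print(f"Connection attempt {attempt} failed ({exc}); retrying in {delay}s")
            time.sleep(delay)
            delay = min(delay * 2, max_delay)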

View File

@@ -1,6 +1,6 @@
outs:
- md5: 7d36b38a27b5b959beec9e0e772c14c4.dir
size: 23067894
nfiles: 8
- md5: 75cc98b7c8fcf782a7d4941594e6bc12.dir
size: 134913
nfiles: 9
hash: md5
path: data

View File

@@ -3,7 +3,6 @@ from sys import stdout
from time import sleep
import pika
import pytest
from kn_utils.logging import logger
logger.remove()

View File

@@ -7,7 +7,7 @@ from fastapi import FastAPI
from pyinfra.storage.connection import get_storage_for_tenant
from pyinfra.storage.utils import (
download_data_as_specified_in_message,
download_data_bytes_as_specified_in_message,
upload_data_as_specified_in_message,
)
from pyinfra.utils.cipher import encrypt
@@ -139,16 +139,6 @@ def payload(payload_type):
}
@pytest.fixture
def expected_data(payload_type):
if payload_type == "target_response_file_path":
return {"data": "success"}
elif payload_type == "dossier_id_file_id":
return {"dossierId": "test", "fileId": "file", "data": "success"}
elif payload_type == "target_file_dict":
return {"file_1": {"data": "success"}, "file_2": {"data": "success"}}
@pytest.mark.parametrize(
"payload_type",
[
@@ -160,17 +150,17 @@ def expected_data(payload_type):
)
@pytest.mark.parametrize("storage_backend", ["azure", "s3"], scope="class")
class TestDownloadAndUploadFromMessage:
def test_download_and_upload_from_message(self, storage, payload, expected_data, payload_type):
def test_download_and_upload_from_message(self, storage, payload, payload_type):
storage.clear_bucket()
upload_data = expected_data if payload_type != "target_file_dict" else expected_data["file_1"]
storage.put_object("test/file.target.json.gz", gzip.compress(json.dumps(upload_data).encode()))
result = {"process_result": "success"}
storage_data = {**payload, "data": result}
packed_data = gzip.compress(json.dumps(storage_data).encode())
data = download_data_as_specified_in_message(storage, payload)
storage.put_object("test/file.target.json.gz", packed_data)
assert data == expected_data
upload_data_as_specified_in_message(storage, payload, expected_data)
_ = download_data_bytes_as_specified_in_message(storage, payload)
upload_data_as_specified_in_message(storage, payload, result)
data = json.loads(gzip.decompress(storage.get_object("test/file.response.json.gz")).decode())
assert data == {**payload, "data": expected_data}
assert data == storage_data
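The rewritten test reflects the breaking change: download_data_bytes_as_specified_in_message now returns raw bytes plus the source path, and decoding is the caller's job. A rough sketch of the round trip a service callback performs, assuming gzipped JSON objects as stored by the test above:

import gzip
import json

from pyinfra.storage.utils import (
    download_data_bytes_as_specified_in_message,
    upload_data_as_specified_in_message,
)

def process_message(storage, payload):
    downloaded = download_data_bytes_as_specified_in_message(storage, payload)
    # Single-file payloads yield {"data": <bytes>, "file_path": <str>}; see the unit tests further below.
    document = json.loads(gzip.decompress(downloaded["data"]))
    result = {"process_result": "success", "num_keys": len(document)}
    # Merges the payload with {"data": result}, gzips it and uploads it to the response path.
    upload_data_as_specified_in_message(storage, payload, result)
    return result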

View File

@@ -1,80 +0,0 @@
import gzip
import json
from pathlib import Path
import pytest
from deepdiff import DeepDiff
from pyinfra.storage.proto_data_loader import ProtoDataLoader
@pytest.fixture
def test_data_dir():
return Path(__file__).parents[1] / "data"
@pytest.fixture
def document_data(request, test_data_dir) -> (str, bytes, dict | list):
doc_type = request.param
input_file_path = test_data_dir / f"72ea04dfdbeb277f37b9eb127efb0896.{doc_type}.proto.gz"
target_file_path = test_data_dir / f"3f9d3d9f255007de8eff13648321e197.{doc_type}.json.gz"
input_data = input_file_path.read_bytes()
target_data = json.loads(gzip.decompress(target_file_path.read_bytes()))
return input_file_path, input_data, target_data
@pytest.fixture
def proto_data_loader():
return ProtoDataLoader()
@pytest.fixture
def should_match():
return [
"a.DOCUMENT_STRUCTURE.proto.gz",
"a.DOCUMENT_TEXT.proto.gz",
"a.DOCUMENT_PAGES.proto.gz",
"a.DOCUMENT_POSITION.proto.gz",
"b.DOCUMENT_STRUCTURE.proto",
"b.DOCUMENT_TEXT.proto",
"b.DOCUMENT_PAGES.proto",
"b.DOCUMENT_POSITION.proto",
"c.STRUCTURE.proto.gz",
"c.TEXT.proto.gz",
"c.PAGES.proto.gz",
"c.POSITION.proto.gz",
]
@pytest.mark.xfail(
reason="FIXME: The test is not stable, but has to work before we can deploy the code! Right now, we don't have parity between the proto and the json data."
)
# As DOCUMENT_POSITION is a very large file, the test takes forever. If you want to test it, add "DOCUMENT_POSITION" to the list below.
@pytest.mark.parametrize("document_data", ["DOCUMENT_STRUCTURE", "DOCUMENT_TEXT", "DOCUMENT_PAGES"], indirect=True)
def test_proto_data_loader_end2end(document_data, proto_data_loader):
file_path, data, target = document_data
data = gzip.decompress(data)
loaded_data = proto_data_loader(file_path, data)
loaded_data_str = json.dumps(loaded_data, sort_keys=True)
target_str = json.dumps(target, sort_keys=True)
diff = DeepDiff(sorted(loaded_data_str), sorted(target_str), ignore_order=True)
# FIXME: remove this block when the test is stable
# if diff:
# with open("/tmp/diff.json", "w") as f:
# f.write(diff.to_json(indent=2))
assert not diff
def test_proto_data_loader_unknown_document_type(proto_data_loader):
assert not proto_data_loader("unknown_document_type.proto", b"")
def test_proto_data_loader_file_name_matching(proto_data_loader, should_match):
for file_name in should_match:
assert proto_data_loader._match(file_name) is not None

View File

@@ -0,0 +1,83 @@
import json
import pytest
from unittest.mock import patch
from pyinfra.storage.utils import (
download_data_bytes_as_specified_in_message,
upload_data_as_specified_in_message,
DownloadedData,
)
from pyinfra.storage.storages.storage import Storage
@pytest.fixture
def mock_storage():
with patch("pyinfra.storage.utils.Storage") as MockStorage:
yield MockStorage()
@pytest.fixture(
params=[
{
"raw_payload": {
"tenantId": "tenant1",
"dossierId": "dossier1",
"fileId": "file1",
"targetFileExtension": "txt",
"responseFileExtension": "json",
},
"expected_result": {
"data": b'{"key": "value"}',
"file_path": "tenant1/dossier1/file1.txt"
}
},
{
"raw_payload": {
"targetFilePath": "some/path/to/file.txt.gz",
"responseFilePath": "some/path/to/file.json"
},
"expected_result": {
"data": b'{"key": "value"}',
"file_path": "some/path/to/file.txt.gz"
}
},
{
"raw_payload": {
"targetFilePath": {
"file1": "some/path/to/file1.txt.gz",
"file2": "some/path/to/file2.txt.gz"
},
"responseFilePath": "some/path/to/file.json"
},
"expected_result": {
"file1": {
"data": b'{"key": "value"}',
"file_path": "some/path/to/file1.txt.gz"
},
"file2": {
"data": b'{"key": "value"}',
"file_path": "some/path/to/file2.txt.gz"
}
}
},
]
)
def payload_and_expected_result(request):
return request.param
def test_download_data_bytes_as_specified_in_message(mock_storage, payload_and_expected_result):
raw_payload = payload_and_expected_result["raw_payload"]
expected_result = payload_and_expected_result["expected_result"]
mock_storage.get_object.return_value = b'{"key": "value"}'
result = download_data_bytes_as_specified_in_message(mock_storage, raw_payload)
assert isinstance(result, dict)
assert result == expected_result
mock_storage.get_object.assert_called()
def test_upload_data_as_specified_in_message(mock_storage, payload_and_expected_result):
raw_payload = payload_and_expected_result["raw_payload"]
data = {"key": "value"}
upload_data_as_specified_in_message(mock_storage, raw_payload, data)
mock_storage.put_object.assert_called_once()
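The third fixture case covers the targetFilePath dict form, where the callback receives one {"data": bytes, "file_path": str} entry per logical file. A small helper for decoding that shape; the gzip-by-extension rule is an assumption based on how the fixtures store *.gz objects, not an API guarantee:

import gzip
import json

def decode_downloaded_files(downloaded: dict) -> dict:
    """Decode a multi-file result of the form {name: {"data": bytes, "file_path": str}}."""
    decoded = {}
    for name, entry in downloaded.items():
        raw = entry["data"]
        if entry["file_path"].endswith(".gz"):  # assumption: *.gz objects arrive still compressed
            raw = gzip.decompress(raw)
        decoded[name] = json.loads(raw)
    return decoded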