Pull request #44: RED-4653
Merge in RR/pyinfra from RED-4653 to master
Squashed commit of the following:
commit 14ed6d2ee79f9a6bc4bad187dc775f7476a05d97
Author: Viktor Seifert <viktor.seifert@iqser.com>
Date: Tue Jul 26 11:08:16 2022 +0200
RED-4653: Disabled coverage check since there are no tests at the moment
commit e926631b167d03e8cc0867db5b5c7d44d6612dcf
Author: Viktor Seifert <viktor.seifert@iqser.com>
Date: Tue Jul 26 10:58:50 2022 +0200
RED-4653: Re-added test execution scripts
commit 94648cc449bbc392864197a1796f99f8953b7312
Author: Viktor Seifert <viktor.seifert@iqser.com>
Date: Tue Jul 26 10:50:42 2022 +0200
RED-4653: Changed error case for processing messages to not requeue the message since that will be handled in DLQ logic
commit d77982dfedcec49482293d79818283c8d7a17dc7
Author: Viktor Seifert <viktor.seifert@iqser.com>
Date: Tue Jul 26 10:46:32 2022 +0200
RED-4653: Removed unnecessary logging message
commit 8c00fd75bf04f8ecc0e9cda654f8e053d4cfb66f
Author: Viktor Seifert <viktor.seifert@iqser.com>
Date: Tue Jul 26 10:03:35 2022 +0200
RED-4653: Re-added wrongly removed config
commit 759d72b3fa093b19f97e68d17bf53390cd5453c7
Author: Viktor Seifert <viktor.seifert@iqser.com>
Date: Tue Jul 26 09:57:47 2022 +0200
RED-4653: Removed leftover Docker commands
commit 2ff5897ee38e39d6507278b6a82176be2450da16
Author: Viktor Seifert <viktor.seifert@iqser.com>
Date: Tue Jul 26 09:48:08 2022 +0200
RED-4653: Removed leftover Docker config
commit 1074167aa98f9f59c0f0f534ba2f1ba09ffb0958
Author: Viktor Seifert <viktor.seifert@iqser.com>
Date: Tue Jul 26 09:41:21 2022 +0200
RED-4653: Removed Docker build stage since it is not needed for a project that is used as a Python module
commit ec769c6cd74a74097d8ebe4800ea6e2ea86236cc
Author: Viktor Seifert <viktor.seifert@iqser.com>
Date: Mon Jul 25 16:11:50 2022 +0200
RED-4653: Renamed function for better clarity and consistency
commit 96e8ac4316ac57aac90066f35422d333c532513b
Author: Viktor Seifert <viktor.seifert@iqser.com>
Date: Mon Jul 25 15:07:40 2022 +0200
RED-4653: Added code to cancel the queue subscription on application exit to queue manager so that it can exit gracefully
commit 64d8e0bd15730898274c08d34f9c34fbac559422
Author: Viktor Seifert <viktor.seifert@iqser.com>
Date: Mon Jul 25 13:57:06 2022 +0200
RED-4653: Moved queue cancellation to a separate method so that it can be called on application exit
commit aff1d06364f5694c5922f37d961e401c12243221
Author: Viktor Seifert <viktor.seifert@iqser.com>
Date: Mon Jul 25 11:51:16 2022 +0200
RED-4653: Re-ordered message processing so that ack occurs after publishing the result, to prevent message loss
commit 9339186b86f2fe9653366c22fcdc9f7fc096b138
Author: Viktor Seifert <viktor.seifert@iqser.com>
Date: Fri Jul 22 18:07:25 2022 +0200
RED-4653: Reordered code to acknowledge message before publishing a result message
commit 2d6fe1cbd95cd86832b086c6dfbcfa62b3ffa16f
Author: Viktor Seifert <viktor.seifert@iqser.com>
Date: Fri Jul 22 17:00:04 2022 +0200
RED-4653: Hopefully corrected storage bucket env var name
commit 8f1ef0dd5532882cb12901721195d9acb336286c
Author: Viktor Seifert <viktor.seifert@iqser.com>
Date: Fri Jul 22 16:37:27 2022 +0200
RED-4653: Switched to validating the connection url via a regex since the validators lib parses our endpoints incorrectly
commit 8d0234fcc5ff7ed1ae7695a17856c6af050065bd
Author: Viktor Seifert <viktor.seifert@iqser.com>
Date: Fri Jul 22 15:02:54 2022 +0200
RED-4653: Corrected exception creation
commit 098a62335b3b695ee409363d429ac07284de7138
Author: Viktor Seifert <viktor.seifert@iqser.com>
Date: Fri Jul 22 14:42:22 2022 +0200
RED-4653: Added a descriptive error message when the storage endpoint is not a correct URL
commit 379685f964a4de641ce6506713f1ea8914a3f5ab
Author: Viktor Seifert <viktor.seifert@iqser.com>
Date: Fri Jul 22 14:11:48 2022 +0200
RED-4653: Removed variable re-use to make the code clearer
commit 4bf1a023453635568e16b1678ef5ad994c534045
Author: Viktor Seifert <viktor.seifert@iqser.com>
Date: Thu Jul 21 17:41:55 2022 +0200
RED-4653: Added explicit conversion of the heartbeat config value to an int before passing it to pika
commit 8f2bc4e028aafdef893458d1433a05724f534fce
Author: Viktor Seifert <viktor.seifert@iqser.com>
Date: Mon Jul 18 16:41:31 2022 +0200
RED-4653: Set heartbeat to lower value so that disconnects are detected more quickly
... and 6 more commits
This commit is contained in:
parent
3f645484d9
commit
e3abf2be0f
19
Dockerfile
19
Dockerfile
@ -1,19 +0,0 @@
|
||||
FROM python:3.8
|
||||
|
||||
# Use a virtual environment.
|
||||
RUN python -m venv /app/venv
|
||||
ENV PATH="/app/venv/bin:$PATH"
|
||||
|
||||
# Upgrade pip.
|
||||
RUN python -m pip install --upgrade pip
|
||||
|
||||
# Make a directory for the service files and copy the service repo into the container.
|
||||
WORKDIR /app/service
|
||||
COPY . .
|
||||
|
||||
# Install module & dependencies
|
||||
RUN python3 -m pip install -e .
|
||||
RUN python3 -m pip install -r requirements.txt
|
||||
|
||||
# Run the service loop.
|
||||
CMD ["python", "src/serve.py"]
|
||||
0
Dockerfile_tests
Executable file → Normal file
0
Dockerfile_tests
Executable file → Normal file
@ -44,7 +44,7 @@ public class PlanSpec {
|
||||
//By default credentials are read from the '.credentials' file.
|
||||
BambooServer bambooServer = new BambooServer("http://localhost:8085");
|
||||
|
||||
Plan plan = new PlanSpec().createDockerBuildPlan();
|
||||
Plan plan = new PlanSpec().createBuildPlan();
|
||||
bambooServer.publish(plan);
|
||||
PlanPermissions planPermission = new PlanSpec().createPlanPermission(plan.getIdentifier());
|
||||
bambooServer.publish(planPermission);
|
||||
@ -67,38 +67,12 @@ public class PlanSpec {
|
||||
.key(new BambooKey("RED"));
|
||||
}
|
||||
|
||||
public Plan createDockerBuildPlan() {
|
||||
public Plan createBuildPlan() {
|
||||
return new Plan(
|
||||
project(),
|
||||
SERVICE_NAME, new BambooKey(SERVICE_KEY))
|
||||
.description("Docker build for pyinfra")
|
||||
.description("Build for pyinfra")
|
||||
.stages(
|
||||
new Stage("Build Stage")
|
||||
.jobs(
|
||||
new Job("Build Job", new BambooKey("BUILD"))
|
||||
.tasks(
|
||||
new CleanWorkingDirectoryTask()
|
||||
.description("Clean working directory.")
|
||||
.enabled(true),
|
||||
new VcsCheckoutTask()
|
||||
.description("Checkout default repository.")
|
||||
.checkoutItems(new CheckoutItem().defaultRepository()),
|
||||
new ScriptTask()
|
||||
.description("Set config and keys.")
|
||||
.inlineBody("mkdir -p ~/.ssh\n" +
|
||||
"echo \"${bamboo.bamboo_agent_ssh}\" | base64 -d >> ~/.ssh/id_rsa\n" +
|
||||
"echo \"host vector.iqser.com\" > ~/.ssh/config\n" +
|
||||
"echo \" user bamboo-agent\" >> ~/.ssh/config\n" +
|
||||
"chmod 600 ~/.ssh/config ~/.ssh/id_rsa"),
|
||||
new ScriptTask()
|
||||
.description("Build Docker container.")
|
||||
.location(Location.FILE)
|
||||
.fileFromPath("bamboo-specs/src/main/resources/scripts/docker-build.sh")
|
||||
.argument(SERVICE_NAME))
|
||||
.dockerConfiguration(
|
||||
new DockerConfiguration()
|
||||
.image("nexus.iqser.com:5001/infra/release_build:4.2.0")
|
||||
.volume("/var/run/docker.sock", "/var/run/docker.sock"))),
|
||||
new Stage("Sonar Stage")
|
||||
.jobs(
|
||||
new Job("Sonar Job", new BambooKey("SONAR"))
|
||||
@ -120,12 +94,7 @@ public class PlanSpec {
|
||||
.description("Run Sonarqube scan.")
|
||||
.location(Location.FILE)
|
||||
.fileFromPath("bamboo-specs/src/main/resources/scripts/sonar-scan.sh")
|
||||
.argument(SERVICE_NAME),
|
||||
new ScriptTask()
|
||||
.description("Shut down any running docker containers.")
|
||||
.location(Location.FILE)
|
||||
.inlineBody("pip install docker-compose\n" +
|
||||
"docker-compose down"))
|
||||
.argument(SERVICE_NAME))
|
||||
.dockerConfiguration(
|
||||
new DockerConfiguration()
|
||||
.image("nexus.iqser.com:5001/infra/release_build:4.2.0")
|
||||
|
||||
@ -1,13 +0,0 @@
|
||||
#!/bin/bash
|
||||
set -e
|
||||
|
||||
SERVICE_NAME=$1
|
||||
|
||||
python3 -m venv build_venv
|
||||
source build_venv/bin/activate
|
||||
python3 -m pip install --upgrade pip
|
||||
|
||||
echo "index-url = https://${bamboo_nexus_user}:${bamboo_nexus_password}@nexus.iqser.com/repository/python-combind/simple" >> pip.conf
|
||||
docker build -f Dockerfile -t nexus.iqser.com:5001/red/$SERVICE_NAME:${bamboo_version_tag} .
|
||||
echo "${bamboo_nexus_password}" | docker login --username "${bamboo_nexus_user}" --password-stdin nexus.iqser.com:5001
|
||||
docker push nexus.iqser.com:5001/red/$SERVICE_NAME:${bamboo_version_tag}
|
||||
@ -10,17 +10,15 @@ python3 -m pip install dependency-check
|
||||
python3 -m pip install docker-compose
|
||||
python3 -m pip install coverage
|
||||
|
||||
echo "docker-compose down"
|
||||
docker-compose down
|
||||
sleep 30
|
||||
# This is disabled since there are currently no tests in this project.
|
||||
# If tests are added this can be enabled again
|
||||
# echo "coverage report generation"
|
||||
# bash run_tests.sh
|
||||
|
||||
echo "coverage report generation"
|
||||
bash run_tests.sh
|
||||
|
||||
if [ ! -f reports/coverage.xml ]
|
||||
then
|
||||
exit 1
|
||||
fi
|
||||
# if [ ! -f reports/coverage.xml ]
|
||||
# then
|
||||
# exit 1
|
||||
# fi
|
||||
|
||||
SERVICE_NAME=$1
|
||||
|
||||
|
||||
@ -1,6 +0,0 @@
|
||||
___ _ _ ___ __
|
||||
o O O | _ \ | || | |_ _| _ _ / _| _ _ __ _
|
||||
o | _/ \_, | | | | ' \ | _| | '_| / _` |
|
||||
TS__[O] _|_|_ _|__/ |___| |_||_| _|_|_ _|_|_ \__,_|
|
||||
{======|_| ``` |_| ````|_|`````|_|`````|_|`````|_|`````|_|`````|
|
||||
./o--000' `-0-0-' `-0-0-' `-0-0-' `-0-0-' `-0-0-' `-0-0-' `-0-0-'
|
||||
35
config.yaml
35
config.yaml
@ -1,35 +0,0 @@
|
||||
service:
|
||||
logging_level: $LOGGING_LEVEL_ROOT|DEBUG # Logging level for service logger
|
||||
|
||||
probing_webserver:
|
||||
host: $PROBING_WEBSERVER_HOST|"0.0.0.0" # Probe webserver address
|
||||
port: $PROBING_WEBSERVER_PORT|8080 # Probe webserver port
|
||||
mode: $PROBING_WEBSERVER_MODE|production # webserver mode: {development, production}
|
||||
|
||||
rabbitmq:
|
||||
host: $RABBITMQ_HOST|localhost # RabbitMQ host address
|
||||
port: $RABBITMQ_PORT|5672 # RabbitMQ host port
|
||||
user: $RABBITMQ_USERNAME|user # RabbitMQ username
|
||||
password: $RABBITMQ_PASSWORD|bitnami # RabbitMQ password
|
||||
heartbeat: $RABBITMQ_HEARTBEAT|7200 # Controls AMQP heartbeat timeout in seconds
|
||||
|
||||
queues:
|
||||
input: $REQUEST_QUEUE|request_queue # Requests to service
|
||||
output: $RESPONSE_QUEUE|response_queue # Responses by service
|
||||
dead_letter: $DEAD_LETTER_QUEUE|dead_letter_queue # Messages that failed to process
|
||||
|
||||
callback:
|
||||
analysis_endpoint: $ANALYSIS_ENDPOINT|"http://127.0.0.1:5000"
|
||||
|
||||
storage:
|
||||
backend: $STORAGE_BACKEND|s3 # The type of storage to use {s3, azure}
|
||||
bucket: "STORAGE_BUCKET_NAME|STORAGE_AZURECONTAINERNAME|pyinfra-test-bucket" # The bucket / container to pull files specified in queue requests from
|
||||
|
||||
s3:
|
||||
endpoint: $STORAGE_ENDPOINT|"http://127.0.0.1:9000"
|
||||
access_key: $STORAGE_KEY|root
|
||||
secret_key: $STORAGE_SECRET|password
|
||||
region: $STORAGE_REGION|"eu-west-1"
|
||||
|
||||
azure:
|
||||
connection_string: $STORAGE_AZURECONNECTIONSTRING|"DefaultEndpointsProtocol=https;AccountName=iqserdevelopment;AccountKey=4imAbV9PYXaztSOMpIyAClg88bAZCXuXMGJG0GA1eIBpdh2PlnFGoRBnKqLy2YZUSTmZ3wJfC7tzfHtuC6FEhQ==;EndpointSuffix=core.windows.net"
|
||||
@ -1,32 +0,0 @@
|
||||
version: '2'
|
||||
services:
|
||||
minio:
|
||||
image: minio/minio
|
||||
ports:
|
||||
- "9000:9000"
|
||||
environment:
|
||||
- MINIO_ROOT_PASSWORD=password
|
||||
- MINIO_ROOT_USER=root
|
||||
volumes:
|
||||
- ./data/minio_store:/data
|
||||
command: server /data
|
||||
network_mode: "bridge"
|
||||
rabbitmq:
|
||||
image: docker.io/bitnami/rabbitmq:3.9
|
||||
ports:
|
||||
- '4369:4369'
|
||||
- '5551:5551'
|
||||
- '5552:5552'
|
||||
- '5672:5672'
|
||||
- '25672:25672'
|
||||
- '15672:15672'
|
||||
environment:
|
||||
- RABBITMQ_SECURE_PASSWORD=yes
|
||||
- RABBITMQ_VM_MEMORY_HIGH_WATERMARK=100%
|
||||
- RABBITMQ_DISK_FREE_ABSOLUTE_LIMIT=20Gi
|
||||
network_mode: "bridge"
|
||||
volumes:
|
||||
- /opt/bitnami/rabbitmq/.rabbitmq/:/data/bitnami
|
||||
volumes:
|
||||
mdata:
|
||||
|
||||
@ -1,55 +1,58 @@
|
||||
"""Implements a config object with dot-indexing syntax."""
|
||||
import os
|
||||
from itertools import chain
|
||||
from operator import truth
|
||||
|
||||
from envyaml import EnvYAML
|
||||
from funcy import first, juxt, butlast, last
|
||||
|
||||
from pyinfra.locations import CONFIG_FILE
|
||||
from os import environ
|
||||
|
||||
|
||||
def _get_item_and_maybe_make_dotindexable(container, item):
|
||||
ret = container[item]
|
||||
return DotIndexable(ret) if isinstance(ret, dict) else ret
|
||||
def read_from_environment(environment_variable_name, default_value):
|
||||
return environ.get(environment_variable_name, default_value)
|
||||
|
||||
|
||||
class DotIndexable:
|
||||
def __init__(self, x):
|
||||
self.x = x
|
||||
class Config(object):
|
||||
def __init__(self):
|
||||
# Logging level for service logger
|
||||
self.logging_level_root = read_from_environment("LOGGING_LEVEL_ROOT", "DEBUG")
|
||||
|
||||
def __getattr__(self, item):
|
||||
return _get_item_and_maybe_make_dotindexable(self.x, item)
|
||||
# RabbitMQ host address
|
||||
self.rabbitmq_host = read_from_environment("RABBITMQ_HOST", "localhost")
|
||||
|
||||
def __repr__(self):
|
||||
return self.x.__repr__()
|
||||
# RabbitMQ host port
|
||||
self.rabbitmq_port = read_from_environment("RABBITMQ_PORT", "5672")
|
||||
|
||||
def __getitem__(self, item):
|
||||
return self.__getattr__(item)
|
||||
# RabbitMQ username
|
||||
self.rabbitmq_username = read_from_environment("RABBITMQ_USERNAME", "user")
|
||||
|
||||
# RabbitMQ password
|
||||
self.rabbitmq_password = read_from_environment("RABBITMQ_PASSWORD", "bitnami")
|
||||
|
||||
# Controls AMQP heartbeat timeout in seconds
|
||||
self.rabbitmq_heartbeat = read_from_environment("RABBITMQ_HEARTBEAT", "60")
|
||||
|
||||
# Queue name for requests to the service
|
||||
self.request_queue = read_from_environment("REQUEST_QUEUE", "request_queue")
|
||||
|
||||
# Queue name for responses by service
|
||||
self.response_queue = read_from_environment("RESPONSE_QUEUE", "response_queue")
|
||||
|
||||
# Queue name for failed messages
|
||||
self.dead_letter_queue = read_from_environment("DEAD_LETTER_QUEUE", "dead_letter_queue")
|
||||
|
||||
# The type of storage to use {s3, azure}
|
||||
self.storage_backend = read_from_environment("STORAGE_BACKEND", "s3")
|
||||
|
||||
# The bucket / container to pull files specified in queue requests from
|
||||
self.storage_bucket = read_from_environment("STORAGE_BUCKET_NAME", "redaction")
|
||||
|
||||
# Endpoint for s3 storage
|
||||
self.storage_endpoint = read_from_environment("STORAGE_ENDPOINT", "http://127.0.0.1:9000")
|
||||
|
||||
# User for s3 storage
|
||||
self.storage_key = read_from_environment("STORAGE_KEY", "root")
|
||||
|
||||
# Password for s3 storage
|
||||
self.storage_secret = read_from_environment("STORAGE_SECRET", "password")
|
||||
|
||||
# Connection string for Azure storage
|
||||
self.storage_azureconnectionstring = read_from_environment("STORAGE_AZURECONNECTIONSTRING",
|
||||
"DefaultEndpointsProtocol=...")
|
||||
|
||||
|
||||
class Config:
|
||||
def __init__(self, config_path):
|
||||
self.__config = EnvYAML(config_path)
|
||||
|
||||
def __getattr__(self, item):
|
||||
if item in self.__config:
|
||||
return _get_item_and_maybe_make_dotindexable(self.__config, item)
|
||||
|
||||
def __getitem__(self, item):
|
||||
return self.__getattr__(item)
|
||||
|
||||
|
||||
CONFIG = Config(CONFIG_FILE)
|
||||
|
||||
|
||||
def parse_disjunction_string(disjunction_string):
|
||||
def try_parse_env_var(disjunction_string):
|
||||
try:
|
||||
return os.environ[disjunction_string]
|
||||
except KeyError:
|
||||
return None
|
||||
|
||||
options = disjunction_string.split("|")
|
||||
identifiers, fallback_value = juxt(butlast, last)(options)
|
||||
return first(chain(filter(truth, map(try_parse_env_var, identifiers)), [fallback_value]))
|
||||
def get_config() -> Config:
|
||||
return Config()
|
||||
|
||||
@ -1,34 +0,0 @@
|
||||
class AnalysisFailure(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class DataLoadingFailure(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class ProcessingFailure(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class UnknownStorageBackend(ValueError):
|
||||
pass
|
||||
|
||||
|
||||
class InvalidEndpoint(ValueError):
|
||||
pass
|
||||
|
||||
|
||||
class UnknownClient(ValueError):
|
||||
pass
|
||||
|
||||
|
||||
class ConsumerError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class NoSuchContainer(KeyError):
|
||||
pass
|
||||
|
||||
|
||||
class IntentionalTestException(RuntimeError):
|
||||
pass
|
||||
@ -1,64 +0,0 @@
|
||||
import logging
|
||||
|
||||
import requests
|
||||
from flask import Flask, jsonify
|
||||
from waitress import serve
|
||||
|
||||
from pyinfra.config import CONFIG
|
||||
|
||||
logger = logging.getLogger(__file__)
|
||||
logger.setLevel(CONFIG.service.logging_level)
|
||||
|
||||
|
||||
def run_probing_webserver(app, host=None, port=None, mode=None):
|
||||
if not host:
|
||||
host = CONFIG.probing_webserver.host
|
||||
|
||||
if not port:
|
||||
port = CONFIG.probing_webserver.port
|
||||
|
||||
if not mode:
|
||||
mode = CONFIG.probing_webserver.mode
|
||||
|
||||
if mode == "development":
|
||||
app.run(host=host, port=port, debug=True)
|
||||
|
||||
elif mode == "production":
|
||||
serve(app, host=host, port=port)
|
||||
|
||||
|
||||
def set_up_probing_webserver():
|
||||
# TODO: implement meaningful checks
|
||||
app = Flask(__name__)
|
||||
informed_about_missing_prometheus_endpoint = False
|
||||
|
||||
@app.route("/ready", methods=["GET"])
|
||||
def ready():
|
||||
resp = jsonify("OK")
|
||||
resp.status_code = 200
|
||||
return resp
|
||||
|
||||
@app.route("/health", methods=["GET"])
|
||||
def healthy():
|
||||
resp = jsonify("OK")
|
||||
resp.status_code = 200
|
||||
return resp
|
||||
|
||||
@app.route("/prometheus", methods=["GET"])
|
||||
def get_metrics_from_analysis_endpoint():
|
||||
nonlocal informed_about_missing_prometheus_endpoint
|
||||
try:
|
||||
resp = requests.get(f"{CONFIG.rabbitmq.callback.analysis_endpoint}/prometheus")
|
||||
resp.raise_for_status()
|
||||
except ConnectionError:
|
||||
return ""
|
||||
except requests.exceptions.HTTPError as err:
|
||||
if resp.status_code == 404:
|
||||
if not informed_about_missing_prometheus_endpoint:
|
||||
logger.warning(f"Got no metrics from analysis prometheus endpoint: {err}")
|
||||
informed_about_missing_prometheus_endpoint = True
|
||||
else:
|
||||
logging.warning(f"Caught {err}")
|
||||
return resp.text
|
||||
|
||||
return app
|
||||
@ -1,18 +0,0 @@
|
||||
"""Defines constant paths relative to the module root path."""
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
MODULE_DIR = Path(__file__).resolve().parents[0]
|
||||
|
||||
PACKAGE_ROOT_DIR = MODULE_DIR.parents[0]
|
||||
|
||||
TEST_DIR = PACKAGE_ROOT_DIR / "test"
|
||||
|
||||
CONFIG_FILE = PACKAGE_ROOT_DIR / "config.yaml"
|
||||
|
||||
TEST_CONFIG_FILE = TEST_DIR / "config.yaml"
|
||||
|
||||
COMPOSE_PATH = PACKAGE_ROOT_DIR
|
||||
|
||||
BANNER_FILE = PACKAGE_ROOT_DIR / "banner.txt"
|
||||
@ -1,16 +0,0 @@
|
||||
from pyinfra.queue.queue_manager.queue_manager import QueueManager
|
||||
|
||||
|
||||
class Consumer:
|
||||
def __init__(self, callback, queue_manager: QueueManager):
|
||||
self.queue_manager = queue_manager
|
||||
self.callback = callback
|
||||
|
||||
def consume_and_publish(self):
|
||||
self.queue_manager.consume_and_publish(self.callback)
|
||||
|
||||
def basic_consume_and_publish(self):
|
||||
self.queue_manager.basic_consume_and_publish(self.callback)
|
||||
|
||||
def consume(self, **kwargs):
|
||||
return self.queue_manager.consume(**kwargs)
|
||||
111
pyinfra/queue/queue_manager.py
Normal file
111
pyinfra/queue/queue_manager.py
Normal file
@ -0,0 +1,111 @@
|
||||
import atexit
|
||||
import json
|
||||
import logging
|
||||
import signal
|
||||
from typing import Callable
|
||||
|
||||
import pika
|
||||
import pika.exceptions
|
||||
|
||||
from pyinfra.config import get_config, Config
|
||||
|
||||
CONFIG = get_config()
|
||||
|
||||
pika_logger = logging.getLogger("pika")
|
||||
pika_logger.setLevel(logging.WARNING)
|
||||
|
||||
|
||||
def get_connection_params(config: Config) -> pika.ConnectionParameters:
|
||||
credentials = pika.PlainCredentials(username=config.rabbitmq_username, password=config.rabbitmq_password)
|
||||
pika_connection_params = {
|
||||
"host": config.rabbitmq_host,
|
||||
"port": config.rabbitmq_port,
|
||||
"credentials": credentials,
|
||||
"heartbeat": int(config.rabbitmq_heartbeat),
|
||||
}
|
||||
|
||||
return pika.ConnectionParameters(**pika_connection_params)
|
||||
|
||||
|
||||
def _get_n_previous_attempts(props):
|
||||
return 0 if props.headers is None else props.headers.get("x-retry-count", 0)
|
||||
|
||||
|
||||
class QueueManager(object):
|
||||
def __init__(self, config: Config = CONFIG):
|
||||
connection_params = get_connection_params(config)
|
||||
|
||||
atexit.register(self.stop_consuming)
|
||||
signal.signal(signal.SIGTERM, self.stop_consuming)
|
||||
signal.signal(signal.SIGINT, self.stop_consuming)
|
||||
|
||||
self._connection = pika.BlockingConnection(parameters=connection_params)
|
||||
self._channel = self._connection.channel()
|
||||
self._channel.basic_qos(prefetch_count=1)
|
||||
|
||||
args = {"x-dead-letter-exchange": "", "x-dead-letter-routing-key": CONFIG.dead_letter_queue}
|
||||
|
||||
self._input_queue = config.request_queue
|
||||
self._output_queue = config.response_queue
|
||||
|
||||
self._channel.queue_declare(self._input_queue, arguments=args, auto_delete=False, durable=True)
|
||||
self._channel.queue_declare(self._output_queue, arguments=args, auto_delete=False, durable=True)
|
||||
|
||||
self._consumer_token = None
|
||||
|
||||
self.logger = logging.getLogger("queue_manager")
|
||||
self.logger.setLevel(CONFIG.logging_level_root)
|
||||
|
||||
def start_consuming(self, process_message_callback: Callable):
|
||||
callback = self._create_queue_callback(process_message_callback)
|
||||
|
||||
self.logger.info("Consuming from queue")
|
||||
self._consumer_token = None
|
||||
try:
|
||||
self._consumer_token = self._channel.basic_consume(self._input_queue, callback)
|
||||
self.logger.info(f"Registered with consumer-tag: {self._consumer_token}")
|
||||
self._channel.start_consuming()
|
||||
finally:
|
||||
self.logger.warning("An unhandled exception occurred while consuming messages. Consuming will stop.")
|
||||
self.stop_consuming()
|
||||
|
||||
def stop_consuming(self):
|
||||
if self._consumer_token and self._connection:
|
||||
self.logger.info(f"Cancelling subscription for consumer-tag: {self._consumer_token}")
|
||||
self._channel.basic_cancel(self._consumer_token)
|
||||
self._connection.close()
|
||||
|
||||
self._consumer_token = None
|
||||
|
||||
def _create_queue_callback(self, process_message_callback: Callable):
|
||||
def callback(_channel, frame, properties, body):
|
||||
self.logger.info(f"Received message from queue with delivery_tag {frame.delivery_tag}")
|
||||
self.logger.debug(f"Processing {(frame, properties, body)}.")
|
||||
|
||||
try:
|
||||
unpacked_message_body = json.loads(body)
|
||||
|
||||
callback_result = process_message_callback(unpacked_message_body)
|
||||
|
||||
self.logger.info("Processed message, publishing result to result-queue")
|
||||
self._channel.basic_publish("", self._output_queue, json.dumps(callback_result).encode())
|
||||
|
||||
self.logger.info(
|
||||
f"Result published, acknowledging incoming message with delivery_tag {frame.delivery_tag}")
|
||||
self._channel.basic_ack(frame.delivery_tag)
|
||||
|
||||
self.logger.info(f"Message with delivery_tag {frame.delivery_tag} processed")
|
||||
except Exception as ex:
|
||||
n_attempts = _get_n_previous_attempts(properties) + 1
|
||||
self.logger.warning(f"Failed to process message, {n_attempts} attempts, error: {str(ex)}")
|
||||
self._channel.basic_nack(frame.delivery_tag, requeue=False)
|
||||
raise ex
|
||||
|
||||
return callback
|
||||
|
||||
def clear(self):
|
||||
try:
|
||||
self._channel.queue_purge(self._input_queue)
|
||||
self._channel.queue_purge(self._output_queue)
|
||||
except pika.exceptions.ChannelWrongStateError:
|
||||
pass
|
||||
@ -1,168 +0,0 @@
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
|
||||
import pika
|
||||
|
||||
from pyinfra.config import CONFIG
|
||||
from pyinfra.exceptions import ProcessingFailure, DataLoadingFailure
|
||||
from pyinfra.queue.queue_manager.queue_manager import QueueHandle, QueueManager
|
||||
|
||||
logger = logging.getLogger("pika")
|
||||
logger.setLevel(logging.WARNING)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
logger.setLevel(CONFIG.service.logging_level)
|
||||
|
||||
|
||||
def monkey_patch_queue_handle(channel, queue) -> QueueHandle:
|
||||
|
||||
empty_message = (None, None, None)
|
||||
|
||||
def is_empty_message(message):
|
||||
return message == empty_message
|
||||
|
||||
queue_handle = QueueHandle()
|
||||
queue_handle.empty = lambda: is_empty_message(channel.basic_get(queue))
|
||||
|
||||
def produce_items():
|
||||
|
||||
while True:
|
||||
message = channel.basic_get(queue)
|
||||
|
||||
if is_empty_message(message):
|
||||
break
|
||||
|
||||
method_frame, properties, body = message
|
||||
channel.basic_ack(method_frame.delivery_tag)
|
||||
yield json.loads(body)
|
||||
|
||||
queue_handle.to_list = lambda: list(produce_items())
|
||||
|
||||
return queue_handle
|
||||
|
||||
|
||||
def get_connection_params():
|
||||
|
||||
credentials = pika.PlainCredentials(username=CONFIG.rabbitmq.user, password=CONFIG.rabbitmq.password)
|
||||
kwargs = {
|
||||
"host": CONFIG.rabbitmq.host,
|
||||
"port": CONFIG.rabbitmq.port,
|
||||
"credentials": credentials,
|
||||
"heartbeat": CONFIG.rabbitmq.heartbeat,
|
||||
}
|
||||
parameters = pika.ConnectionParameters(**kwargs)
|
||||
|
||||
return parameters
|
||||
|
||||
|
||||
def get_n_previous_attempts(props):
|
||||
return 0 if props.headers is None else props.headers.get("x-retry-count", 0)
|
||||
|
||||
|
||||
def attempts_remain(n_attempts, max_attempts):
|
||||
return n_attempts < max_attempts
|
||||
|
||||
|
||||
class PikaQueueManager(QueueManager):
|
||||
def __init__(self, input_queue, output_queue, dead_letter_queue=None, connection_params=None):
|
||||
super().__init__(input_queue, output_queue)
|
||||
|
||||
if not connection_params:
|
||||
connection_params = get_connection_params()
|
||||
|
||||
self.connection = pika.BlockingConnection(parameters=connection_params)
|
||||
self.channel = self.connection.channel()
|
||||
self.channel.basic_qos(prefetch_count=1)
|
||||
|
||||
if not dead_letter_queue:
|
||||
dead_letter_queue = CONFIG.rabbitmq.queues.dead_letter
|
||||
|
||||
args = {"x-dead-letter-exchange": "", "x-dead-letter-routing-key": dead_letter_queue}
|
||||
|
||||
self.channel.queue_declare(input_queue, arguments=args, auto_delete=False, durable=True)
|
||||
self.channel.queue_declare(output_queue, arguments=args, auto_delete=False, durable=True)
|
||||
|
||||
def republish(self, body, n_current_attempts, frame):
|
||||
self.channel.basic_publish(
|
||||
exchange="",
|
||||
routing_key=self._input_queue,
|
||||
body=body,
|
||||
properties=pika.BasicProperties(headers={"x-retry-count": n_current_attempts}),
|
||||
)
|
||||
self.channel.basic_ack(delivery_tag=frame.delivery_tag)
|
||||
|
||||
def publish_request(self, request):
|
||||
logger.debug(f"Publishing {request}")
|
||||
self.channel.basic_publish("", self._input_queue, json.dumps(request).encode())
|
||||
|
||||
def reject(self, body, frame):
|
||||
logger.error(f"Adding to dead letter queue: {body}")
|
||||
self.channel.basic_reject(delivery_tag=frame.delivery_tag, requeue=False)
|
||||
|
||||
def publish_response(self, message, callback, max_attempts=3):
|
||||
|
||||
logger.debug(f"Processing {message}.")
|
||||
|
||||
frame, properties, body = message
|
||||
|
||||
n_attempts = get_n_previous_attempts(properties) + 1
|
||||
|
||||
try:
|
||||
response = json.dumps(callback(json.loads(body)))
|
||||
self.channel.basic_publish("", self._output_queue, response.encode())
|
||||
self.channel.basic_ack(frame.delivery_tag)
|
||||
except (ProcessingFailure, DataLoadingFailure):
|
||||
|
||||
logger.error(f"Message failed to process {n_attempts}/{max_attempts} times: {body}")
|
||||
|
||||
if attempts_remain(n_attempts, max_attempts):
|
||||
self.republish(body, n_attempts, frame)
|
||||
else:
|
||||
self.reject(body, frame)
|
||||
|
||||
def pull_request(self):
|
||||
return self.channel.basic_get(self._input_queue)
|
||||
|
||||
def consume(self, inactivity_timeout=None):
|
||||
logger.debug("Consuming")
|
||||
return self.channel.consume(self._input_queue, inactivity_timeout=inactivity_timeout)
|
||||
|
||||
def consume_and_publish(self, visitor):
|
||||
|
||||
logger.info(f"Consuming with callback {visitor.callback.__name__}")
|
||||
|
||||
for message in self.consume():
|
||||
self.publish_response(message, visitor)
|
||||
|
||||
def basic_consume_and_publish(self, visitor):
|
||||
|
||||
logger.info(f"Basic consuming with callback {visitor.callback.__name__}")
|
||||
|
||||
def callback(channel, frame, properties, body):
|
||||
message = (frame, properties, body)
|
||||
return self.publish_response(message, visitor)
|
||||
|
||||
consumer_tag = None
|
||||
|
||||
try:
|
||||
consumer_tag = self.channel.basic_consume(self._input_queue, callback)
|
||||
self.channel.start_consuming()
|
||||
finally:
|
||||
if consumer_tag:
|
||||
self.channel.basic_cancel(consumer_tag)
|
||||
|
||||
def clear(self):
|
||||
try:
|
||||
self.channel.queue_purge(self._input_queue)
|
||||
self.channel.queue_purge(self._output_queue)
|
||||
except pika.exceptions.ChannelWrongStateError:
|
||||
pass
|
||||
|
||||
@property
|
||||
def input_queue(self) -> QueueHandle:
|
||||
return monkey_patch_queue_handle(self.channel, self._input_queue)
|
||||
|
||||
@property
|
||||
def output_queue(self) -> QueueHandle:
|
||||
return monkey_patch_queue_handle(self.channel, self._output_queue)
|
||||
@ -1,51 +0,0 @@
|
||||
import abc
|
||||
|
||||
|
||||
class QueueHandle:
|
||||
def empty(self) -> bool:
|
||||
raise NotImplementedError
|
||||
|
||||
def to_list(self) -> list:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class QueueManager(abc.ABC):
|
||||
def __init__(self, input_queue, output_queue):
|
||||
self._input_queue = input_queue
|
||||
self._output_queue = output_queue
|
||||
|
||||
@abc.abstractmethod
|
||||
def publish_request(self, request):
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
def publish_response(self, response, callback):
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
def pull_request(self):
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
def consume(self, **kwargs):
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
def clear(self):
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
def input_queue(self) -> QueueHandle:
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
def output_queue(self) -> QueueHandle:
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
def consume_and_publish(self, callback):
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
def basic_consume_and_publish(self, callback):
|
||||
raise NotImplementedError
|
||||
@ -1,34 +0,0 @@
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
|
||||
class StorageAdapter(ABC):
|
||||
def __init__(self, client):
|
||||
self.__client = client
|
||||
|
||||
@abstractmethod
|
||||
def make_bucket(self, bucket_name):
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def has_bucket(self, bucket_name):
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def put_object(self, bucket_name, object_name, data):
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def get_object(self, bucket_name, object_name):
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def get_all_objects(self, bucket_name):
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def clear_bucket(self, bucket_name):
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def get_all_object_names(self, bucket_name):
|
||||
raise NotImplementedError
|
||||
@ -3,30 +3,32 @@ from itertools import repeat
|
||||
from operator import attrgetter
|
||||
|
||||
from azure.storage.blob import ContainerClient, BlobServiceClient
|
||||
from retry import retry
|
||||
|
||||
from pyinfra.storage.adapters.adapter import StorageAdapter
|
||||
from pyinfra.config import Config, get_config
|
||||
|
||||
CONFIG = get_config()
|
||||
logger = logging.getLogger(CONFIG.logging_level_root)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
logging.getLogger("azure").setLevel(logging.WARNING)
|
||||
logging.getLogger("urllib3").setLevel(logging.WARNING)
|
||||
|
||||
|
||||
class AzureStorageAdapter(StorageAdapter):
|
||||
def __init__(self, client):
|
||||
super().__init__(client=client)
|
||||
self.__client: BlobServiceClient = self._StorageAdapter__client
|
||||
class AzureStorageAdapter(object):
|
||||
def __init__(self, client: BlobServiceClient):
|
||||
self._client: BlobServiceClient = client
|
||||
|
||||
def has_bucket(self, bucket_name):
|
||||
container_client = self.__client.get_container_client(bucket_name)
|
||||
container_client = self._client.get_container_client(bucket_name)
|
||||
return container_client.exists()
|
||||
|
||||
def make_bucket(self, bucket_name):
|
||||
container_client = self.__client.get_container_client(bucket_name)
|
||||
container_client if container_client.exists() else self.__client.create_container(bucket_name)
|
||||
container_client = self._client.get_container_client(bucket_name)
|
||||
container_client if container_client.exists() else self._client.create_container(bucket_name)
|
||||
|
||||
def __provide_container_client(self, bucket_name) -> ContainerClient:
|
||||
self.make_bucket(bucket_name)
|
||||
container_client = self.__client.get_container_client(bucket_name)
|
||||
container_client = self._client.get_container_client(bucket_name)
|
||||
return container_client
|
||||
|
||||
def put_object(self, bucket_name, object_name, data):
|
||||
@ -35,15 +37,19 @@ class AzureStorageAdapter(StorageAdapter):
|
||||
blob_client = container_client.get_blob_client(object_name)
|
||||
blob_client.upload_blob(data, overwrite=True)
|
||||
|
||||
@retry(tries=3, delay=5, jitter=(1, 3))
|
||||
def get_object(self, bucket_name, object_name):
|
||||
logger.debug(f"Downloading '{object_name}'...")
|
||||
container_client = self.__provide_container_client(bucket_name)
|
||||
blob_client = container_client.get_blob_client(object_name)
|
||||
blob_data = blob_client.download_blob()
|
||||
return blob_data.readall()
|
||||
|
||||
try:
|
||||
container_client = self.__provide_container_client(bucket_name)
|
||||
blob_client = container_client.get_blob_client(object_name)
|
||||
blob_data = blob_client.download_blob()
|
||||
return blob_data.readall()
|
||||
except Exception as err:
|
||||
raise Exception("Failed getting object from azure client") from err
|
||||
|
||||
def get_all_objects(self, bucket_name):
|
||||
|
||||
container_client = self.__provide_container_client(bucket_name)
|
||||
blobs = container_client.list_blobs()
|
||||
for blob in blobs:
|
||||
@ -55,7 +61,7 @@ class AzureStorageAdapter(StorageAdapter):
|
||||
|
||||
def clear_bucket(self, bucket_name):
|
||||
logger.debug(f"Clearing Azure container '{bucket_name}'...")
|
||||
container_client = self.__client.get_container_client(bucket_name)
|
||||
container_client = self._client.get_container_client(bucket_name)
|
||||
blobs = container_client.list_blobs()
|
||||
container_client.delete_blobs(*blobs)
|
||||
|
||||
@ -63,3 +69,8 @@ class AzureStorageAdapter(StorageAdapter):
|
||||
container_client = self.__provide_container_client(bucket_name)
|
||||
blobs = container_client.list_blobs()
|
||||
return zip(repeat(bucket_name), map(attrgetter("name"), blobs))
|
||||
|
||||
|
||||
def get_azure_storage(config: Config):
|
||||
return AzureStorageAdapter(
|
||||
BlobServiceClient.from_connection_string(conn_str=config.storage_azureconnectionstring))
|
||||
|
||||
@ -1,58 +1,88 @@
|
||||
import io
|
||||
from itertools import repeat
|
||||
import logging
|
||||
import re
|
||||
from itertools import repeat
|
||||
from operator import attrgetter
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from minio import Minio
|
||||
from retry import retry
|
||||
|
||||
from pyinfra.exceptions import DataLoadingFailure
|
||||
from pyinfra.storage.adapters.adapter import StorageAdapter
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
from pyinfra.config import Config, get_config
|
||||
|
||||
|
||||
class S3StorageAdapter(StorageAdapter):
|
||||
def __init__(self, client):
|
||||
super().__init__(client=client)
|
||||
self.__client: Minio = self._StorageAdapter__client
|
||||
CONFIG = get_config()
|
||||
logger = logging.getLogger(CONFIG.logging_level_root)
|
||||
|
||||
ALLOWED_CONNECTION_SCHEMES = {"http", "https"}
|
||||
URL_VALIDATOR = re.compile(
|
||||
r"^((" +
|
||||
r"([A-Za-z]{3,9}:(?:\/\/)?)" +
|
||||
r"(?:[\-;:&=\+\$,\w]+@)?" + r"[A-Za-z0-9\.\-]+|(?:www\.|[\-;:&=\+\$,\w]+@)" +
|
||||
r"[A-Za-z0-9\.\-]+)" + r"((?:\/[\+~%\/\.\w\-_]*)?" +
|
||||
r"\??(?:[\-\+=&;%@\.\w_]*)#?(?:[\.\!\/\\\w]*))?)")
|
||||
|
||||
|
||||
class S3StorageAdapter(object):
|
||||
def __init__(self, client: Minio):
|
||||
self._client = client
|
||||
|
||||
def make_bucket(self, bucket_name):
|
||||
if not self.has_bucket(bucket_name):
|
||||
self.__client.make_bucket(bucket_name)
|
||||
self._client.make_bucket(bucket_name)
|
||||
|
||||
def has_bucket(self, bucket_name):
|
||||
return self.__client.bucket_exists(bucket_name)
|
||||
return self._client.bucket_exists(bucket_name)
|
||||
|
||||
def put_object(self, bucket_name, object_name, data):
|
||||
logger.debug(f"Uploading '{object_name}'...")
|
||||
data = io.BytesIO(data)
|
||||
self.__client.put_object(bucket_name, object_name, data, length=data.getbuffer().nbytes)
|
||||
self._client.put_object(bucket_name, object_name, data, length=data.getbuffer().nbytes)
|
||||
|
||||
@retry(tries=3, delay=5, jitter=(1, 3))
|
||||
def get_object(self, bucket_name, object_name):
|
||||
logger.debug(f"Downloading '{object_name}'...")
|
||||
response = None
|
||||
|
||||
try:
|
||||
response = self.__client.get_object(bucket_name, object_name)
|
||||
response = self._client.get_object(bucket_name, object_name)
|
||||
return response.data
|
||||
except Exception as err:
|
||||
raise DataLoadingFailure("Failed getting object from s3 client") from err
|
||||
raise Exception("Failed getting object from s3 client") from err
|
||||
finally:
|
||||
if response:
|
||||
response.close()
|
||||
response.release_conn()
|
||||
|
||||
def get_all_objects(self, bucket_name):
|
||||
for obj in self.__client.list_objects(bucket_name, recursive=True):
|
||||
for obj in self._client.list_objects(bucket_name, recursive=True):
|
||||
logger.debug(f"Downloading '{obj.object_name}'...")
|
||||
yield self.get_object(bucket_name, obj.object_name)
|
||||
|
||||
def clear_bucket(self, bucket_name):
|
||||
logger.debug(f"Clearing S3 bucket '{bucket_name}'...")
|
||||
objects = self.__client.list_objects(bucket_name, recursive=True)
|
||||
objects = self._client.list_objects(bucket_name, recursive=True)
|
||||
for obj in objects:
|
||||
self.__client.remove_object(bucket_name, obj.object_name)
|
||||
self._client.remove_object(bucket_name, obj.object_name)
|
||||
|
||||
def get_all_object_names(self, bucket_name):
|
||||
objs = self.__client.list_objects(bucket_name, recursive=True)
|
||||
objs = self._client.list_objects(bucket_name, recursive=True)
|
||||
return zip(repeat(bucket_name), map(attrgetter("object_name"), objs))
|
||||
|
||||
|
||||
def _parse_endpoint(endpoint):
|
||||
parsed_url = urlparse(endpoint)
|
||||
if URL_VALIDATOR.match(endpoint) and parsed_url.netloc and parsed_url.scheme in ALLOWED_CONNECTION_SCHEMES:
|
||||
return {"secure": parsed_url.scheme == "https", "endpoint": parsed_url.netloc}
|
||||
else:
|
||||
raise Exception(f"The configured storage endpoint is not a valid url: {endpoint}")
|
||||
|
||||
|
||||
def get_s3_storage(config: Config):
|
||||
return S3StorageAdapter(Minio(
|
||||
**_parse_endpoint(config.storage_endpoint),
|
||||
access_key=config.storage_key,
|
||||
secret_key=config.storage_secret,
|
||||
# FIXME Is this still needed? Check and if yes, add it to config
|
||||
# region=config.region,
|
||||
))
|
||||
|
||||
@ -1,11 +0,0 @@
|
||||
from azure.storage.blob import BlobServiceClient
|
||||
|
||||
from pyinfra.config import CONFIG
|
||||
|
||||
|
||||
def get_azure_client(connection_string=None) -> BlobServiceClient:
|
||||
|
||||
if not connection_string:
|
||||
connection_string = CONFIG.storage.azure.connection_string
|
||||
|
||||
return BlobServiceClient.from_connection_string(conn_str=connection_string)
|
||||
@ -1,40 +0,0 @@
|
||||
import re
|
||||
|
||||
from minio import Minio
|
||||
|
||||
from pyinfra.config import CONFIG
|
||||
from pyinfra.exceptions import InvalidEndpoint
|
||||
|
||||
|
||||
def parse_endpoint(endpoint):
|
||||
# FIXME Greedy matching (.+) since we get random storage names on kubernetes (eg http://red-research-headless:9000)
|
||||
# FIXME this has been broken and accepts invalid URLs
|
||||
endpoint_pattern = r"(?P<protocol>https?)*(?:://)*(?P<address>(?:(?:(?:\d{1,3}\.){3}\d{1,3})|.+)(?:\:\d+)?)"
|
||||
|
||||
match = re.match(endpoint_pattern, endpoint)
|
||||
|
||||
if not match:
|
||||
raise InvalidEndpoint(f"Endpoint {endpoint} is invalid; expected {endpoint_pattern}")
|
||||
|
||||
return {"secure": match.group("protocol") == "https", "endpoint": match.group("address")}
|
||||
|
||||
|
||||
def get_s3_client(params=None) -> Minio:
|
||||
"""
|
||||
Args:
|
||||
params: dict like
|
||||
{
|
||||
"endpoint": <storage_endpoint>
|
||||
"access_key": <storage_key>
|
||||
"secret_key": <storage_secret>
|
||||
}
|
||||
"""
|
||||
if not params:
|
||||
params = CONFIG.storage.s3
|
||||
|
||||
return Minio(
|
||||
**parse_endpoint(params.endpoint),
|
||||
access_key=params.access_key,
|
||||
secret_key=params.secret_key,
|
||||
region=params.region,
|
||||
)
|
||||
@ -1,44 +1,21 @@
|
||||
import logging
|
||||
|
||||
from retry import retry
|
||||
|
||||
from pyinfra.config import CONFIG
|
||||
from pyinfra.exceptions import DataLoadingFailure
|
||||
from pyinfra.storage.adapters.adapter import StorageAdapter
|
||||
from pyinfra.config import get_config, Config
|
||||
from pyinfra.storage.adapters.azure import get_azure_storage
|
||||
from pyinfra.storage.adapters.s3 import get_s3_storage
|
||||
|
||||
CONFIG = get_config()
|
||||
logger = logging.getLogger(__name__)
|
||||
logger.setLevel(CONFIG.service.logging_level)
|
||||
logger.setLevel(CONFIG.logging_level_root)
|
||||
|
||||
|
||||
class Storage:
|
||||
def __init__(self, adapter: StorageAdapter):
|
||||
self.__adapter = adapter
|
||||
def get_storage(config: Config):
|
||||
|
||||
def make_bucket(self, bucket_name):
|
||||
self.__adapter.make_bucket(bucket_name)
|
||||
if config.storage_backend == "s3":
|
||||
storage = get_s3_storage(config)
|
||||
elif config.storage_backend == "azure":
|
||||
storage = get_azure_storage(config)
|
||||
else:
|
||||
raise Exception(f"Unknown storage backend '{config.storage_backend}'.")
|
||||
|
||||
def has_bucket(self, bucket_name):
|
||||
return self.__adapter.has_bucket(bucket_name)
|
||||
|
||||
def put_object(self, bucket_name, object_name, data):
|
||||
self.__adapter.put_object(bucket_name, object_name, data)
|
||||
|
||||
def get_object(self, bucket_name, object_name):
|
||||
return self.__get_object(bucket_name, object_name)
|
||||
|
||||
@retry(DataLoadingFailure, tries=3, delay=5, jitter=(1, 3))
|
||||
def __get_object(self, bucket_name, object_name):
|
||||
try:
|
||||
return self.__adapter.get_object(bucket_name, object_name)
|
||||
except Exception as err:
|
||||
logging.error(err)
|
||||
raise DataLoadingFailure from err
|
||||
|
||||
def get_all_objects(self, bucket_name):
|
||||
return self.__adapter.get_all_objects(bucket_name)
|
||||
|
||||
def clear_bucket(self, bucket_name):
|
||||
return self.__adapter.clear_bucket(bucket_name)
|
||||
|
||||
def get_all_object_names(self, bucket_name):
|
||||
return self.__adapter.get_all_object_names(bucket_name)
|
||||
return storage
|
||||
|
||||
@ -1,26 +0,0 @@
|
||||
from pyinfra.exceptions import UnknownStorageBackend
|
||||
from pyinfra.storage.adapters.azure import AzureStorageAdapter
|
||||
from pyinfra.storage.adapters.s3 import S3StorageAdapter
|
||||
from pyinfra.storage.clients.azure import get_azure_client
|
||||
from pyinfra.storage.clients.s3 import get_s3_client
|
||||
from pyinfra.storage.storage import Storage
|
||||
|
||||
|
||||
def get_azure_storage(config=None):
|
||||
return Storage(AzureStorageAdapter(get_azure_client(config)))
|
||||
|
||||
|
||||
def get_s3_storage(config=None):
|
||||
return Storage(S3StorageAdapter(get_s3_client(config)))
|
||||
|
||||
|
||||
def get_storage(storage_backend):
|
||||
|
||||
if storage_backend == "s3":
|
||||
storage = get_s3_storage()
|
||||
elif storage_backend == "azure":
|
||||
storage = get_azure_storage()
|
||||
else:
|
||||
raise UnknownStorageBackend(f"Unknown storage backend '{storage_backend}'.")
|
||||
|
||||
return storage
|
||||
@ -1,21 +0,0 @@
|
||||
import logging
|
||||
|
||||
from pyinfra.locations import BANNER_FILE
|
||||
|
||||
|
||||
def show_banner():
|
||||
with open(BANNER_FILE) as f:
|
||||
banner = "\n" + "".join(f.readlines()) + "\n"
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
logger.propagate = False
|
||||
|
||||
handler = logging.StreamHandler()
|
||||
handler.setLevel(logging.INFO)
|
||||
|
||||
formatter = logging.Formatter("")
|
||||
|
||||
handler.setFormatter(formatter)
|
||||
logger.addHandler(handler)
|
||||
|
||||
logger.info(banner)
|
||||
@ -1,91 +0,0 @@
|
||||
import abc
|
||||
import gzip
|
||||
import json
|
||||
import logging
|
||||
from operator import itemgetter
|
||||
from typing import Callable
|
||||
|
||||
from pyinfra.config import CONFIG, parse_disjunction_string
|
||||
from pyinfra.exceptions import DataLoadingFailure
|
||||
from pyinfra.storage.storage import Storage
|
||||
|
||||
|
||||
def get_object_name(body):
|
||||
dossier_id, file_id, target_file_extension = itemgetter("dossierId", "fileId", "targetFileExtension")(body)
|
||||
object_name = f"{dossier_id}/{file_id}.{target_file_extension}"
|
||||
return object_name
|
||||
|
||||
|
||||
def get_response_object_name(body):
|
||||
dossier_id, file_id, response_file_extension = itemgetter("dossierId", "fileId", "responseFileExtension")(body)
|
||||
object_name = f"{dossier_id}/{file_id}.{response_file_extension}"
|
||||
return object_name
|
||||
|
||||
|
||||
def get_object_descriptor(body):
|
||||
return {"bucket_name": parse_disjunction_string(CONFIG.storage.bucket), "object_name": get_object_name(body)}
|
||||
|
||||
|
||||
def get_response_object_descriptor(body):
|
||||
return {
|
||||
"bucket_name": parse_disjunction_string(CONFIG.storage.bucket),
|
||||
"object_name": get_response_object_name(body),
|
||||
}
|
||||
|
||||
|
||||
class ResponseStrategy(abc.ABC):
|
||||
@abc.abstractmethod
|
||||
def handle_response(self, body):
|
||||
pass
|
||||
|
||||
def __call__(self, body):
|
||||
return self.handle_response(body)
|
||||
|
||||
|
||||
class StorageStrategy(ResponseStrategy):
|
||||
def __init__(self, storage):
|
||||
self.storage = storage
|
||||
|
||||
def handle_response(self, body):
|
||||
self.storage.put_object(**get_response_object_descriptor(body), data=gzip.compress(json.dumps(body).encode()))
|
||||
body.pop("data")
|
||||
return body
|
||||
|
||||
|
||||
class ForwardingStrategy(ResponseStrategy):
|
||||
def handle_response(self, body):
|
||||
return body
|
||||
|
||||
|
||||
class QueueVisitor:
|
||||
def __init__(self, storage: Storage, callback: Callable, response_strategy):
|
||||
self.storage = storage
|
||||
self.callback = callback
|
||||
self.response_strategy = response_strategy
|
||||
|
||||
def load_data(self, body):
|
||||
def download():
|
||||
logging.debug(f"Downloading {object_descriptor}...")
|
||||
data = self.storage.get_object(**object_descriptor)
|
||||
logging.debug(f"Downloaded {object_descriptor}.")
|
||||
return data
|
||||
|
||||
object_descriptor = get_object_descriptor(body)
|
||||
|
||||
try:
|
||||
return gzip.decompress(download())
|
||||
except Exception as err:
|
||||
logging.warning(f"Loading data from storage failed for {object_descriptor}.")
|
||||
raise DataLoadingFailure from err
|
||||
|
||||
def process_data(self, data, body):
|
||||
return self.callback({**body, "data": data})
|
||||
|
||||
def load_and_process(self, body):
|
||||
data = self.process_data(self.load_data(body), body)
|
||||
result_body = {**body, "data": data}
|
||||
return result_body
|
||||
|
||||
def __call__(self, body):
|
||||
result_body = self.load_and_process(body)
|
||||
return self.response_strategy(result_body)
|
||||
@ -1,15 +1,9 @@
|
||||
pika==1.2.0
|
||||
retry==0.9.2
|
||||
envyaml==1.10.211231
|
||||
minio==7.1.3
|
||||
Flask==2.1.1
|
||||
waitress==2.0.0
|
||||
azure-core==1.22.1
|
||||
azure-storage-blob==12.9.0
|
||||
requests==2.27.1
|
||||
testcontainers==3.4.2
|
||||
docker-compose==1.29.2
|
||||
tqdm==4.62.3
|
||||
pytest~=7.0.1
|
||||
funcy==1.17
|
||||
fpdf==1.7.2
|
||||
|
||||
0
run_tests.sh
Executable file → Normal file
0
run_tests.sh
Executable file → Normal file
@ -1,72 +0,0 @@
|
||||
import argparse
|
||||
import gzip
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
from tqdm import tqdm
|
||||
|
||||
from pyinfra.config import CONFIG, parse_disjunction_string
|
||||
from pyinfra.storage.storages import get_s3_storage
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
subparsers = parser.add_subparsers(help="sub-command help", dest="command")
|
||||
|
||||
parser_add = subparsers.add_parser("add", help="Add file(s) to the MinIO store")
|
||||
parser_add.add_argument("dossier_id")
|
||||
add_group = parser_add.add_mutually_exclusive_group(required=True)
|
||||
add_group.add_argument("--file", "-f")
|
||||
add_group.add_argument("--directory", "-d")
|
||||
|
||||
subparsers.add_parser("purge", help="Delete all files and buckets in the MinIO store")
|
||||
|
||||
args = parser.parse_args()
|
||||
return args
|
||||
|
||||
|
||||
def combine_dossier_id_and_file_id_and_extension(dossier_id, file_id, extension):
|
||||
return f"{dossier_id}/{file_id}{extension}"
|
||||
|
||||
|
||||
def upload_compressed_response(storage, bucket_name, dossier_id, file_id, result) -> None:
|
||||
data = gzip.compress(result.encode())
|
||||
path_gz = combine_dossier_id_and_file_id_and_extension(dossier_id, file_id, CONFIG.service.response.extension)
|
||||
storage.put_object(bucket_name, path_gz, data)
|
||||
|
||||
|
||||
def add_file_compressed(storage, bucket_name, dossier_id, path) -> None:
|
||||
if Path(path).suffix == ".pdf":
|
||||
suffix_gz = ".ORIGIN.pdf.gz"
|
||||
if Path(path).suffix == ".json":
|
||||
suffix_gz = ".TEXT.json.gz"
|
||||
path_gz = combine_dossier_id_and_file_id_and_extension(dossier_id, Path(path).stem, suffix_gz)
|
||||
|
||||
with open(path, "rb") as f:
|
||||
data = gzip.compress(f.read())
|
||||
storage.put_object(bucket_name, path_gz, data)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
storage = get_s3_storage()
|
||||
bucket_name = parse_disjunction_string(CONFIG.storage.bucket)
|
||||
|
||||
if not storage.has_bucket(bucket_name):
|
||||
storage.make_bucket(bucket_name)
|
||||
|
||||
args = parse_args()
|
||||
|
||||
if args.command == "add":
|
||||
|
||||
if args.file:
|
||||
add_file_compressed(storage, bucket_name, args.dossier_id, args.file)
|
||||
|
||||
elif args.directory:
|
||||
for fname in tqdm([*os.listdir(args.directory)], desc="Adding files"):
|
||||
path = Path(args.directory) / fname
|
||||
add_file_compressed(storage, bucket_name, args.dossier_id, path)
|
||||
|
||||
elif args.command == "purge":
|
||||
storage.clear_bucket(bucket_name)
|
||||
@ -1,88 +0,0 @@
|
||||
import argparse
|
||||
import json
|
||||
|
||||
import pika
|
||||
|
||||
from pyinfra.config import CONFIG, parse_disjunction_string
|
||||
from pyinfra.storage.storages import get_s3_storage
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--bucket_name", "-b", required=True)
|
||||
parser.add_argument("--analysis_container", "-a", choices=["detr", "ner", "image", "dl_error"], required=True)
|
||||
args = parser.parse_args()
|
||||
return args
|
||||
|
||||
|
||||
def read_connection_params():
|
||||
credentials = pika.PlainCredentials(CONFIG.rabbitmq.user, CONFIG.rabbitmq.password)
|
||||
parameters = pika.ConnectionParameters(
|
||||
host=CONFIG.rabbitmq.host,
|
||||
port=CONFIG.rabbitmq.port,
|
||||
heartbeat=CONFIG.rabbitmq.heartbeat,
|
||||
credentials=credentials,
|
||||
)
|
||||
return parameters
|
||||
|
||||
|
||||
def make_channel(connection) -> pika.adapters.blocking_connection.BlockingChannel:
|
||||
channel = connection.channel()
|
||||
channel.basic_qos(prefetch_count=1)
|
||||
return channel
|
||||
|
||||
|
||||
def declare_queue(channel, queue: str):
|
||||
args = {"x-dead-letter-exchange": "", "x-dead-letter-routing-key": CONFIG.rabbitmq.queues.dead_letter}
|
||||
return channel.queue_declare(queue=queue, auto_delete=False, durable=True, arguments=args)
|
||||
|
||||
|
||||
def make_connection() -> pika.BlockingConnection:
|
||||
parameters = read_connection_params()
|
||||
connection = pika.BlockingConnection(parameters)
|
||||
return connection
|
||||
|
||||
|
||||
def build_message_bodies(analyse_container_type, bucket_name):
|
||||
def update_message(message_dict):
|
||||
if analyse_container_type == "detr" or analyse_container_type == "image":
|
||||
message_dict.update({"targetFileExtension": "ORIGIN.pdf.gz", "responseFileExtension": "IMAGE_INFO.json.gz"})
|
||||
if analyse_container_type == "dl_error":
|
||||
message_dict.update({"targetFileExtension": "no_such_file", "responseFileExtension": "IMAGE_INFO.json.gz"})
|
||||
if analyse_container_type == "ner":
|
||||
message_dict.update(
|
||||
{"targetFileExtension": "TEXT.json.gz", "responseFileExtension": "NER_ENTITIES.json.gz"}
|
||||
)
|
||||
return message_dict
|
||||
|
||||
storage = get_s3_storage()
|
||||
for bucket_name, pdf_name in storage.get_all_object_names(bucket_name):
|
||||
if "pdf" not in pdf_name:
|
||||
continue
|
||||
file_id = pdf_name.split(".")[0]
|
||||
dossier_id, file_id = file_id.split("/")
|
||||
message_dict = {"dossierId": dossier_id, "fileId": file_id}
|
||||
update_message(message_dict)
|
||||
yield json.dumps(message_dict).encode()
|
||||
|
||||
|
||||
def main(args):
|
||||
connection = make_connection()
|
||||
channel = make_channel(connection)
|
||||
declare_queue(channel, CONFIG.rabbitmq.queues.input)
|
||||
declare_queue(channel, CONFIG.rabbitmq.queues.output)
|
||||
|
||||
for body in build_message_bodies(args.analysis_container, args.bucket_name):
|
||||
channel.basic_publish("", CONFIG.rabbitmq.queues.input, body)
|
||||
print(f"Put {body} on {CONFIG.rabbitmq.queues.input}")
|
||||
|
||||
for method_frame, _, body in channel.consume(queue=CONFIG.rabbitmq.queues.output, inactivity_timeout=1):
|
||||
if not body:
|
||||
break
|
||||
print(f"Received {json.loads(body)}")
|
||||
channel.basic_ack(method_frame.delivery_tag)
|
||||
channel.close()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main(parse_args())
|
||||
84
src/serve.py
84
src/serve.py
@ -1,84 +0,0 @@
|
||||
import logging
|
||||
from multiprocessing import Process
|
||||
|
||||
import requests
|
||||
from retry import retry
|
||||
|
||||
from pyinfra.config import CONFIG
|
||||
from pyinfra.exceptions import AnalysisFailure, ConsumerError
|
||||
from pyinfra.flask import run_probing_webserver, set_up_probing_webserver
|
||||
from pyinfra.queue.consumer import Consumer
|
||||
from pyinfra.queue.queue_manager.pika_queue_manager import PikaQueueManager
|
||||
from pyinfra.storage.storages import get_storage
|
||||
from pyinfra.utils.banner import show_banner
|
||||
from pyinfra.visitor import QueueVisitor, StorageStrategy
|
||||
|
||||
|
||||
def make_callback(analysis_endpoint):
|
||||
def callback(message):
|
||||
def perform_operation(operation):
|
||||
endpoint = f"{analysis_endpoint}/{operation}"
|
||||
try:
|
||||
logging.debug(f"Requesting analysis from {endpoint}...")
|
||||
analysis_response = requests.post(endpoint, data=message["data"])
|
||||
analysis_response.raise_for_status()
|
||||
analysis_response = analysis_response.json()
|
||||
logging.debug(f"Received response.")
|
||||
return analysis_response
|
||||
except Exception as err:
|
||||
logging.warning(f"Exception caught when calling analysis endpoint {endpoint}.")
|
||||
raise AnalysisFailure() from err
|
||||
|
||||
operations = message.get("operations", ["/"])
|
||||
results = map(perform_operation, operations)
|
||||
result = dict(zip(operations, results))
|
||||
if list(result.keys()) == ["/"]:
|
||||
result = list(result.values())[0]
|
||||
return result
|
||||
|
||||
return callback
|
||||
|
||||
|
||||
def main():
|
||||
logger = logging.getLogger("main")
|
||||
|
||||
show_banner()
|
||||
|
||||
webserver = Process(target=run_probing_webserver, args=(set_up_probing_webserver(),))
|
||||
logging.info("Starting webserver...")
|
||||
webserver.start()
|
||||
|
||||
callback = make_callback(CONFIG.rabbitmq.callback.analysis_endpoint)
|
||||
storage = get_storage(CONFIG.storage.backend)
|
||||
response_strategy = StorageStrategy(storage)
|
||||
visitor = QueueVisitor(storage, callback, response_strategy)
|
||||
|
||||
@retry(ConsumerError, tries=3, delay=5, jitter=(1, 3))
|
||||
def consume():
|
||||
try: # RED-4049 queue manager needs to be in try scope to eventually throw Exception after connection loss.
|
||||
queue_manager = PikaQueueManager(CONFIG.rabbitmq.queues.input, CONFIG.rabbitmq.queues.output)
|
||||
consumer = Consumer(visitor, queue_manager)
|
||||
consumer.basic_consume_and_publish()
|
||||
except Exception as err:
|
||||
logger.exception(err)
|
||||
raise ConsumerError from err
|
||||
|
||||
try:
|
||||
consume()
|
||||
except KeyboardInterrupt:
|
||||
pass
|
||||
except ConsumerError:
|
||||
webserver.terminate()
|
||||
raise
|
||||
|
||||
webserver.join()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
logging_level = CONFIG.service.logging_level
|
||||
logging.basicConfig(level=logging_level)
|
||||
logging.getLogger("pika").setLevel(logging.ERROR)
|
||||
logging.getLogger("flask").setLevel(logging.ERROR)
|
||||
logging.getLogger("urllib3").setLevel(logging.ERROR)
|
||||
|
||||
main()
|
||||
@ -1,5 +0,0 @@
|
||||
from pyinfra.config import Config
|
||||
from pyinfra.locations import TEST_CONFIG_FILE
|
||||
|
||||
|
||||
CONFIG = Config(TEST_CONFIG_FILE)
|
||||
@ -1,25 +0,0 @@
|
||||
storage:
|
||||
minio:
|
||||
endpoint: "http://127.0.0.1:9000"
|
||||
access_key: root
|
||||
secret_key: password
|
||||
region: null
|
||||
|
||||
aws:
|
||||
endpoint: https://s3.amazonaws.com
|
||||
access_key: AKIA4QVP6D4LCDAGYGN2
|
||||
secret_key: 8N6H1TUHTsbvW2qMAm7zZlJ63hMqjcXAsdN7TYED
|
||||
region: $STORAGE_REGION|"eu-west-1"
|
||||
|
||||
azure:
|
||||
connection_string: "DefaultEndpointsProtocol=https;AccountName=iqserdevelopment;AccountKey=4imAbV9PYXaztSOMpIyAClg88bAZCXuXMGJG0GA1eIBpdh2PlnFGoRBnKqLy2YZUSTmZ3wJfC7tzfHtuC6FEhQ==;EndpointSuffix=core.windows.net"
|
||||
|
||||
bucket: "pyinfra-test-bucket"
|
||||
|
||||
webserver:
|
||||
host: $SERVER_HOST|"127.0.0.1" # webserver address
|
||||
port: $SERVER_PORT|5000 # webserver port
|
||||
mode: $SERVER_MODE|production # webserver mode: {development, production}
|
||||
|
||||
|
||||
mock_analysis_endpoint: "http://127.0.0.1:5000"
|
||||
@ -1,79 +0,0 @@
|
||||
import json
|
||||
from operator import itemgetter
|
||||
|
||||
import pytest
|
||||
from flask import Flask, request, jsonify
|
||||
import fpdf
|
||||
|
||||
|
||||
def set_up_processing_server():
|
||||
app = Flask(__name__)
|
||||
|
||||
@app.route("/ready", methods=["GET"])
|
||||
def ready():
|
||||
resp = jsonify("OK")
|
||||
resp.status_code = 200
|
||||
return resp
|
||||
|
||||
@app.route("/process", methods=["POST"])
|
||||
def process():
|
||||
payload = json.loads(request.json)
|
||||
data = payload["data"].encode()
|
||||
metadata = payload["metadata"]
|
||||
|
||||
response_payload = {"metadata_type": str(type(metadata)), "data_type": str(type(data))}
|
||||
|
||||
return jsonify(response_payload)
|
||||
|
||||
return app
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def server():
|
||||
server = set_up_processing_server()
|
||||
server.config.update({"TESTING": True})
|
||||
return server
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def client(server):
|
||||
return server.test_client()
|
||||
|
||||
|
||||
def test_server_ready_check(client):
|
||||
response = client.get("/ready")
|
||||
assert response.status_code == 200
|
||||
assert response.json == "OK"
|
||||
|
||||
|
||||
@pytest.mark.parametrize("data_type", ["pdf", "bytestring"])
|
||||
def test_sending_bytes_through_json(client, data):
|
||||
payload = {"data": data.decode("latin1"), "metadata": {"A": 1, "B": [2, 3]}}
|
||||
|
||||
response = client.post("/process", json=json.dumps(payload))
|
||||
|
||||
response_payload = response.json
|
||||
data_type, metadata_type = itemgetter("data_type", "metadata_type")(response_payload)
|
||||
|
||||
assert data_type == "<class 'bytes'>"
|
||||
assert metadata_type == "<class 'dict'>"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def pdf():
|
||||
pdf = fpdf.FPDF(unit="pt")
|
||||
pdf.add_page()
|
||||
|
||||
return pdf_stream(pdf)
|
||||
|
||||
|
||||
def pdf_stream(pdf: fpdf.fpdf.FPDF):
|
||||
return pdf.output(dest="S").encode("latin1")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def data(data_type, pdf):
|
||||
if data_type == "pdf":
|
||||
return pdf
|
||||
elif data_type == "bytestring":
|
||||
return "content".encode("latin1")
|
||||
@ -1,46 +0,0 @@
|
||||
from pyinfra.queue.queue_manager.queue_manager import QueueManager, QueueHandle
|
||||
from test.queue.queue_mock import QueueMock
|
||||
|
||||
|
||||
def monkey_patch_queue_handle(queue) -> QueueHandle:
|
||||
queue_handle = QueueHandle()
|
||||
queue_handle.empty = lambda: not queue
|
||||
queue_handle.to_list = lambda: list(queue)
|
||||
return queue_handle
|
||||
|
||||
|
||||
class QueueManagerMock(QueueManager):
    """In-memory QueueManager double backed by two QueueMock deques."""

    def __init__(self, input_queue, output_queue):
        # NOTE(review): the input_queue/output_queue arguments are accepted
        # only to mirror the real manager's signature; fresh QueueMock
        # instances are used regardless -- confirm this is intentional.
        super().__init__(QueueMock(), QueueMock())

    def publish_request(self, request):
        """Append a request to the input queue."""
        self._input_queue.append(request)

    def publish_response(self, message, callback):
        """Run ``callback`` on ``message`` and append the result to the output queue."""
        self._output_queue.append(callback(message))

    def pull_request(self):
        """Pop and return the oldest pending request (FIFO)."""
        return self._input_queue.popleft()

    def consume(self, **kwargs):
        """Yield requests until the input queue is drained; kwargs are ignored."""
        while self._input_queue:
            yield self.pull_request()

    def consume_and_publish(self, callback):
        """Process every pending request through ``callback``."""
        for pending in self.consume():
            self.publish_response(pending, callback)

    def basic_consume_and_publish(self, callback):
        """Not supported by the in-memory mock."""
        raise NotImplementedError

    def clear(self):
        """Empty both queues."""
        self._input_queue.clear()
        self._output_queue.clear()

    @property
    def input_queue(self) -> QueueHandle:
        """Inspection handle over the input queue."""
        return monkey_patch_queue_handle(self._input_queue)

    @property
    def output_queue(self) -> QueueHandle:
        """Inspection handle over the output queue."""
        return monkey_patch_queue_handle(self._output_queue)
|
||||
@ -1,5 +0,0 @@
|
||||
from collections import deque
|
||||
|
||||
|
||||
class QueueMock(deque):
    """Deque-backed stand-in for a message queue in tests (FIFO via append/popleft)."""
|
||||
@ -1,30 +0,0 @@
|
||||
from pyinfra.storage.adapters.adapter import StorageAdapter
|
||||
from test.storage.client_mock import StorageClientMock
|
||||
|
||||
|
||||
class StorageAdapterMock(StorageAdapter):
    """StorageAdapter double that delegates every call to a StorageClientMock."""

    def __init__(self, client: StorageClientMock):
        assert isinstance(client, StorageClientMock)
        super().__init__(client=client)
        # The parent stores the client under a name-mangled private attribute
        # (_StorageAdapter__client); re-bind it here so this class's own
        # mangled self.__client refers to the same object.
        self.__client = self._StorageAdapter__client

    def make_bucket(self, bucket_name):
        """Create the named bucket on the mock client."""
        self.__client.make_bucket(bucket_name)

    def has_bucket(self, bucket_name):
        """True if the named bucket exists."""
        return self.__client.has_bucket(bucket_name)

    def put_object(self, bucket_name, object_name, data):
        """Store ``data`` under ``object_name`` in the bucket."""
        return self.__client.put_object(bucket_name, object_name, data)

    def get_object(self, bucket_name, object_name):
        """Fetch one object's data."""
        return self.__client.get_object(bucket_name, object_name)

    def get_all_objects(self, bucket_name):
        """Fetch the data of every object in the bucket."""
        return self.__client.get_all_objects(bucket_name)

    def clear_bucket(self, bucket_name):
        """Remove every object from the bucket."""
        return self.__client.clear_bucket(bucket_name)

    def get_all_object_names(self, bucket_name):
        """List (bucket, object_name) pairs for the bucket's contents."""
        return self.__client.get_all_object_names(bucket_name)
|
||||
@ -1,27 +0,0 @@
|
||||
from itertools import repeat
|
||||
|
||||
|
||||
class StorageClientMock:
    """In-memory storage client: each bucket is a dict of object_name -> data."""

    def __init__(self):
        self.__data = {}

    def make_bucket(self, bucket_name):
        """Create (or reset) the named bucket."""
        self.__data[bucket_name] = {}

    def has_bucket(self, bucket_name):
        """True if the named bucket exists."""
        return bucket_name in self.__data

    def put_object(self, bucket_name, object_name, data):
        """Store ``data`` under ``object_name``; the bucket must already exist."""
        bucket = self.__data[bucket_name]
        bucket[object_name] = data

    def get_object(self, bucket_name, object_name):
        """Return the stored data; raises KeyError if bucket/object is missing."""
        return self.__data[bucket_name][object_name]

    def get_all_objects(self, bucket_name):
        """Iterable over the data of every object in the bucket."""
        return self.__data[bucket_name].values()

    def clear_bucket(self, bucket_name):
        """Drop all objects by replacing the bucket mapping with a fresh dict."""
        self.__data[bucket_name] = {}

    def get_all_object_names(self, bucket_name):
        """Pairs of (bucket_name, object_name), matching the real client's shape."""
        return zip(repeat(bucket_name), self.__data[bucket_name])
|
||||
@ -1,10 +0,0 @@
|
||||
import pytest
|
||||
|
||||
from pyinfra.storage.adapters.azure import AzureStorageAdapter
|
||||
from test.storage.client_mock import StorageClientMock
|
||||
|
||||
|
||||
@pytest.fixture
def adapter():
    """Azure storage adapter wired to the in-memory client mock."""
    return AzureStorageAdapter(StorageClientMock())
|
||||
@ -1,45 +0,0 @@
|
||||
import os
|
||||
import tempfile
|
||||
|
||||
import pytest
|
||||
import yaml
|
||||
|
||||
from pyinfra.config import Config, parse_disjunction_string
|
||||
|
||||
|
||||
@pytest.fixture
def config_file_content():
    """Nested mapping exercising lists, dicts, and scalars in the YAML config."""
    return {
        "A": [{"B": [1, 2]}, {"C": 3}, 4],
        "D": {"E": {"F": True}},
    }
|
||||
|
||||
|
||||
@pytest.fixture
def config(config_file_content):
    """Write the fixture content to a temporary YAML file and load it as Config.

    Yields inside the context manager so the file survives for the test's duration.
    """
    with tempfile.NamedTemporaryFile(suffix=".yaml", mode="w") as handle:
        yaml.dump(config_file_content, handle, default_flow_style=False)
        yield Config(handle.name)
|
||||
|
||||
|
||||
def test_dot_access_key_exists(config):
    """Existing keys are reachable via attribute (dot) access, nested included."""
    assert [{"B": [1, 2]}, {"C": 3}, 4] == config.A
    assert config.D.E["F"]
|
||||
|
||||
|
||||
def test_access_key_exists(config):
    """Existing keys are reachable via item access, down to nested list elements."""
    expected = [{"B": [1, 2]}, {"C": 3}, 4]
    assert expected == config["A"]
    assert {"B": [1, 2]} == config["A"][0]
    assert [1, 2] == config["A"][0]["B"]
    assert 1 == config["A"][0]["B"][0]
|
||||
|
||||
|
||||
def test_dot_access_key_does_not_exists(config):
    """Missing keys resolve to None under attribute access rather than raising."""
    assert config.B is None
|
||||
|
||||
|
||||
def test_access_key_does_not_exists(config):
    """Missing keys resolve to None under item access rather than raising."""
    assert config["B"] is None
|
||||
|
||||
|
||||
def test_parse_disjunction_string():
    """The disjunction resolves to the first set environment variable, else the literal.

    Fix: the original left ``Bb`` set in os.environ, polluting later tests and
    making a re-run of this test fail its first assertion; clean up in finally.
    """
    assert parse_disjunction_string("A|Bb|c") == "c"
    os.environ["Bb"] = "d"
    try:
        assert parse_disjunction_string("A|Bb|c") == "d"
    finally:
        # Restore the environment so other tests (and re-runs) start clean.
        os.environ.pop("Bb", None)
|
||||
@ -1,160 +0,0 @@
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
from unittest.mock import Mock
|
||||
|
||||
import pika
|
||||
import pytest
|
||||
import testcontainers.compose
|
||||
|
||||
from pyinfra.exceptions import UnknownClient
|
||||
from pyinfra.locations import TEST_DIR, COMPOSE_PATH
|
||||
from pyinfra.queue.queue_manager.pika_queue_manager import PikaQueueManager, get_connection_params
|
||||
from pyinfra.queue.queue_manager.queue_manager import QueueManager
|
||||
from pyinfra.storage.adapters.azure import AzureStorageAdapter
|
||||
from pyinfra.storage.adapters.s3 import S3StorageAdapter
|
||||
from pyinfra.storage.clients.azure import get_azure_client
|
||||
from pyinfra.storage.clients.s3 import get_s3_client
|
||||
from pyinfra.storage.storage import Storage
|
||||
from test.config import CONFIG
|
||||
from test.queue.queue_manager_mock import QueueManagerMock
|
||||
from test.storage.adapter_mock import StorageAdapterMock
|
||||
from test.storage.client_mock import StorageClientMock
|
||||
from pyinfra.visitor import StorageStrategy, ForwardingStrategy, QueueVisitor
|
||||
|
||||
# Module logger at DEBUG so fixture setup/teardown messages show in test output.
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
def bucket_name():
    """Dedicated bucket name shared by all storage tests in the session."""
    return "pyinfra-test-bucket"
|
||||
|
||||
|
||||
@pytest.fixture
def storage_data():
    """Load the canned JSON test document from the test data directory."""
    with open(f"{TEST_DIR}/test_data/test_data.TEXT.json", "r") as handle:
        return json.load(handle)
|
||||
|
||||
|
||||
@pytest.fixture
def mock_response(storage_data):
    """HTTP-response double: status 200 with .json() returning the test document."""
    mocked = Mock(status_code=200)
    mocked.json.return_value = storage_data
    return mocked
|
||||
|
||||
|
||||
@pytest.fixture
def mock_payload():
    """Serialized message body carrying the identifiers the consumer expects."""
    payload = {"dossierId": "test", "fileId": "test"}
    return json.dumps(payload)
|
||||
|
||||
|
||||
@pytest.fixture
def mock_make_load_data(storage_data):
    """Factory for a load_data stand-in that always returns the test document.

    Fix: the inner function referenced ``storage_data`` which was not in scope
    (NameError on first call); the fixture now declares the ``storage_data``
    fixture dependency explicitly.
    """
    def load_data(payload):
        # Ignores the payload by design -- the double always yields the canned data.
        return storage_data

    return load_data
|
||||
|
||||
|
||||
@pytest.fixture(params=["minio", "aws"], scope="session")
def storage(client_name, bucket_name, request):
    """Session-scoped Storage bound to the parametrized S3 backend.

    The bucket is created and emptied before the tests run and emptied again
    on teardown so backends start and finish clean.
    """
    logger.debug("Setup for storage")
    store = Storage(get_adapter(client_name, request.param))
    store.make_bucket(bucket_name)
    store.clear_bucket(bucket_name)
    yield store
    logger.debug("Teardown for storage")
    store.clear_bucket(bucket_name)
|
||||
|
||||
|
||||
@pytest.fixture(scope="session", autouse=True)
def docker_compose(sleep_seconds=30):
    """Start the docker-compose stack once per session and stop it on teardown.

    sleep_seconds: fixed grace period for the containers to finish booting.
    """
    logger.info(f"Starting docker containers with {COMPOSE_PATH}/docker-compose.yml...")
    compose = testcontainers.compose.DockerCompose(COMPOSE_PATH, compose_file_name="docker-compose.yml")
    compose.start()
    logger.info(f"Sleeping for {sleep_seconds} seconds to wait for containers to finish startup... ")
    time.sleep(sleep_seconds)
    yield compose
    compose.stop()
|
||||
|
||||
|
||||
def get_pika_connection_params():
    """Connection parameters for the Pika queue manager."""
    return get_connection_params()
|
||||
|
||||
|
||||
def get_s3_params(s3_backend):
    """Storage configuration section for the given S3 backend (e.g. minio, aws)."""
    return CONFIG.storage[s3_backend]
|
||||
|
||||
|
||||
def get_adapter(client_name, s3_backend):
    """Build the storage adapter for ``client_name``.

    Raises UnknownClient for unrecognized client names.
    """
    if client_name == "mock":
        return StorageAdapterMock(StorageClientMock())
    if client_name == "azure":
        return AzureStorageAdapter(get_azure_client(CONFIG.storage.azure.connection_string))
    if client_name == "s3":
        return S3StorageAdapter(get_s3_client(get_s3_params(s3_backend)))
    raise UnknownClient(client_name)
|
||||
|
||||
|
||||
def get_queue_manager(queue_manager_name) -> QueueManager:
    """Build the queue manager for ``queue_manager_name`` ("mock" or "pika").

    Fix: an unrecognized name previously fell through and returned None,
    producing a confusing AttributeError later; it now raises UnknownClient,
    consistent with get_adapter.
    """
    if queue_manager_name == "mock":
        return QueueManagerMock("input", "output")
    if queue_manager_name == "pika":
        return PikaQueueManager("input", "output", connection_params=get_pika_connection_params())
    raise UnknownClient(queue_manager_name)
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
def queue_manager(queue_manager_name):
    # Session-scoped queue manager; for the pika backend, the connection and
    # channel are explicitly closed on teardown. The inner helpers close over
    # the ``queue_manager`` local that is assigned AFTER their definition --
    # they are only called post-yield, when the binding exists.
    def close_connections():
        if queue_manager_name == "pika":
            try:
                queue_manager.connection.close()
            except (pika.exceptions.StreamLostError, pika.exceptions.ConnectionWrongStateError, ConnectionResetError):
                # Broker teardown may have already dropped the connection.
                logger.debug("Connection was already closed when attempting to close explicitly.")

    def close_channel():
        if queue_manager_name == "pika":
            try:
                queue_manager.channel.close()
            except pika.exceptions.ChannelWrongStateError:
                # Closing the connection first can leave the channel already closed.
                logger.debug("Channel was already closed when attempting to close explicitly.")

    queue_manager = get_queue_manager(queue_manager_name)
    yield queue_manager
    close_connections()
    close_channel()
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
def callback():
    """Processing-callback double: decode the payload bytes and duplicate them."""
    def duplicate(request):
        return request["data"].decode() * 2

    return duplicate
|
||||
|
||||
|
||||
@pytest.fixture
def analysis_callback(callback):
    """Thin delegating wrapper around the session ``callback`` fixture."""
    def delegate(request):
        return callback(request)

    return delegate
|
||||
|
||||
|
||||
@pytest.fixture
def response_strategy(response_strategy_name, storage):
    """Select the response strategy under test by its parametrized name."""
    if response_strategy_name == "forwarding":
        return ForwardingStrategy()
    if response_strategy_name == "storage":
        return StorageStrategy(storage)
|
||||
|
||||
|
||||
@pytest.fixture()
def visitor(storage, analysis_callback, response_strategy):
    """QueueVisitor wired with the parametrized storage, callback, and strategy."""
    return QueueVisitor(storage, analysis_callback, response_strategy)
|
||||
@ -1,126 +0,0 @@
|
||||
import gzip
|
||||
import json
|
||||
import logging
|
||||
from operator import itemgetter
|
||||
|
||||
import pytest
|
||||
|
||||
from pyinfra.exceptions import ProcessingFailure
|
||||
from pyinfra.queue.consumer import Consumer
|
||||
from pyinfra.visitor import get_object_descriptor, ForwardingStrategy
|
||||
|
||||
# Module logger at DEBUG so per-item consumption progress shows in test output.
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
def consumer(queue_manager, callback):
    """Consumer under test, bound to the session queue manager and callback."""
    return Consumer(callback, queue_manager)
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
def access_callback():
    """Callable extracting the fileId field from a message body."""
    return itemgetter("fileId")
|
||||
|
||||
|
||||
@pytest.fixture()
def items():
    """Three (payload-bytes, message-body) pairs addressing distinct files in one folder."""
    def make_body(index):
        return {
            "dossierId": "folder",
            "fileId": f"file{index}",
            "targetFileExtension": "in.gz",
            "responseFileExtension": "out.gz",
        }

    return [(f"{index}".encode(), make_body(index)) for index in range(3)]
|
||||
|
||||
|
||||
class TestConsumer:
    # Each test runs against both the in-memory mock and a real pika broker
    # via the session-scoped queue_manager_name parametrization.

    @pytest.mark.parametrize("queue_manager_name", ["mock", "pika"], scope="session")
    def test_consuming_empty_input_queue_does_not_put_anything_on_output_queue(self, consumer, queue_manager):
        # Consuming a freshly-cleared input queue must leave the output queue empty.
        queue_manager.clear()
        consumer.consume()
        assert queue_manager.output_queue.empty()

    @pytest.mark.parametrize("queue_manager_name", ["mock", "pika"], scope="session")
    def test_consuming_nonempty_input_queue_puts_messages_on_output_queue_in_fifo_order(
        self, consumer, queue_manager, callback
    ):
        # Producer is re-invoked for the zip below, so it must be a fresh iterable each time.
        def produce_items():
            return map(str, range(3))

        # Adapt the session callback (which expects {"data": bytes}) to raw strings.
        def mock_visitor(callback):
            def inner(data):
                return callback({"data": data.encode()})

            return inner

        callback = mock_visitor(callback)

        queue_manager.clear()

        for item in produce_items():
            queue_manager.publish_request(item)

        # consume() is lazy; responses are published while iterating.
        requests = consumer.consume()

        # zip bounds the iteration to the number of produced items.
        for _, r in zip(produce_items(), requests):
            queue_manager.publish_response(r, callback)

        # The session callback duplicates each decoded payload, FIFO order preserved.
        assert queue_manager.output_queue.to_list() == ["00", "11", "22"]

    @pytest.mark.parametrize("queue_manager_name", ["mock", "pika"], scope="session")
    @pytest.mark.parametrize("client_name", ["mock", "s3", "azure"], scope="session")
    @pytest.mark.parametrize("response_strategy_name", ["forwarding", "storage"], scope="session")
    def test_consuming_nonempty_input_queue_with_visitor_puts_messages_on_output_queue_in_fifo_order(
        self, consumer, queue_manager, visitor, bucket_name, storage, items
    ):
        # Force forwarding so results land on the output queue regardless of
        # the parametrized strategy.
        visitor.response_strategy = ForwardingStrategy()

        queue_manager.clear()
        storage.clear_bucket(bucket_name)

        # Seed storage with gzipped payloads and enqueue the matching messages.
        for data, message in items:
            storage.put_object(**get_object_descriptor(message), data=gzip.compress(data))
            queue_manager.publish_request(message)

        # Timeout keeps the pika consumer from blocking forever on an empty queue.
        requests = consumer.consume(inactivity_timeout=5)

        for itm, req in zip(items, requests):
            logger.debug(f"Processing item {itm}")
            queue_manager.publish_response(req, visitor)

        assert list(map(itemgetter("data"), queue_manager.output_queue.to_list())) == ["00", "11", "22"]

    @pytest.mark.parametrize("queue_manager_name", ["pika"], scope="session")
    def test_message_is_republished_when_callback_raises_processing_failure_exception(
        self, consumer, queue_manager, bucket_name, items
    ):
        # NOTE(review): the name says "republished" but the test patches
        # ``reject`` -- per RED-4653 failures are now rejected for DLQ handling,
        # not requeued; consider renaming the test to match.
        class DebugError(Exception):
            pass

        def callback(_):
            raise ProcessingFailure()

        # Sentinel: raising from reject proves the manager rejected the message.
        def reject_patch(*args, **kwargs):
            raise DebugError()

        queue_manager.reject = reject_patch

        queue_manager.clear()

        for data, message in items:
            queue_manager.publish_request(message)

        requests = consumer.consume()

        # Silence the manager's expected error logging for this negative test.
        logger = logging.getLogger("pyinfra.queue.queue_manager.pika_queue_manager")
        logger.addFilter(lambda record: False)

        with pytest.raises(DebugError):
            while True:
                queue_manager.publish_response(next(requests), callback)
|
||||
@ -1,38 +0,0 @@
|
||||
import gzip
|
||||
import json
|
||||
|
||||
import pytest
|
||||
|
||||
from pyinfra.visitor import get_object_descriptor, get_response_object_descriptor
|
||||
|
||||
|
||||
@pytest.fixture()
def body():
    """Message body describing one stored object and its response extension."""
    return {
        "dossierId": "folder",
        "fileId": "file",
        "targetFileExtension": "in.gz",
        "responseFileExtension": "out.gz",
    }
|
||||
|
||||
|
||||
@pytest.mark.parametrize("client_name", ["mock", "azure", "s3"], scope="session")
class TestVisitor:
    """Visitor behavior exercised against every storage client implementation."""

    @pytest.mark.parametrize("response_strategy_name", ["forwarding", "storage"], scope="session")
    def test_given_a_input_queue_message_callback_pulls_the_data_from_storage(
        self, visitor, body, storage, bucket_name
    ):
        # Seed the object described by the body, then confirm it is decompressed on load.
        storage.clear_bucket(bucket_name)
        storage.put_object(**get_object_descriptor(body), data=gzip.compress(b"content"))
        received = visitor.load_data(body)
        assert received == b"content"

    @pytest.mark.parametrize("response_strategy_name", ["forwarding", "storage"], scope="session")
    def test_visitor_pulls_and_processes_data(self, visitor, body, storage, bucket_name):
        # The session callback duplicates the decoded payload ("2" -> "22").
        storage.clear_bucket(bucket_name)
        storage.put_object(**get_object_descriptor(body), data=gzip.compress(b"2"))
        response = visitor.load_and_process(body)
        assert response["data"] == "22"

    @pytest.mark.parametrize("response_strategy_name", ["storage"], scope="session")
    def test_visitor_puts_response_on_storage(self, visitor, body, storage, bucket_name):
        storage.clear_bucket(bucket_name)
        storage.put_object(**get_object_descriptor(body), data=gzip.compress(b"2"))
        response = visitor(body)
        # Storage strategy strips the payload from the body and writes it to storage.
        assert "data" not in response
        stored = storage.get_object(**get_response_object_descriptor(body))
        assert json.loads(gzip.decompress(stored))["data"] == "22"
|
||||
@ -1,52 +0,0 @@
|
||||
import logging
|
||||
|
||||
import pytest
|
||||
|
||||
from pyinfra.exceptions import DataLoadingFailure
|
||||
|
||||
# Module logger at DEBUG so storage round-trip details show in test output.
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("client_name", ["mock", "azure", "s3"], scope="session")
class TestStorage:
    """Round-trip storage behavior exercised against every client implementation."""

    def test_clearing_bucket_yields_empty_bucket(self, storage, bucket_name):
        storage.clear_bucket(bucket_name)
        received = storage.get_all_objects(bucket_name)
        assert not set(received)

    def test_getting_object_put_in_bucket_is_object(self, storage, bucket_name):
        storage.clear_bucket(bucket_name)
        storage.put_object(bucket_name, "file", b"content")
        received = storage.get_object(bucket_name, "file")
        assert received == b"content"

    def test_getting_nested_object_put_in_bucket_is_nested_object(self, storage, bucket_name):
        # Slash-separated names act as folder paths.
        storage.clear_bucket(bucket_name)
        storage.put_object(bucket_name, "folder/file", b"content")
        received = storage.get_object(bucket_name, "folder/file")
        assert received == b"content"

    def test_getting_objects_put_in_bucket_are_objects(self, storage, bucket_name):
        storage.clear_bucket(bucket_name)
        storage.put_object(bucket_name, "file1", b"content 1")
        storage.put_object(bucket_name, "folder/file2", b"content 2")
        received = storage.get_all_objects(bucket_name)
        assert set(received) == {b"content 1", b"content 2"}

    def test_make_bucket_produces_bucket(self, storage, bucket_name):
        storage.clear_bucket(bucket_name)
        storage.make_bucket(bucket_name)
        assert storage.has_bucket(bucket_name)

    def test_listing_bucket_files_yields_all_files_in_bucket(self, storage, bucket_name):
        storage.clear_bucket(bucket_name)
        storage.put_object(bucket_name, "file1", b"content 1")
        storage.put_object(bucket_name, "file2", b"content 2")
        names = storage.get_all_object_names(bucket_name)
        assert set(names) == {(bucket_name, "file1"), (bucket_name, "file2")}

    def test_data_loading_failure_raised_if_object_not_present(self, storage, bucket_name):
        storage.clear_bucket(bucket_name)
        with pytest.raises(DataLoadingFailure):
            storage.get_object(bucket_name, "folder/file")
|
||||
Loading…
x
Reference in New Issue
Block a user