Pull request #20: Pika encapsulation

Merge in RR/pyinfra from pika_encapsulation to master

Squashed commit of the following:

commit 251fea4094062a72f5f0b1f8a54f959a1d7309ec
Author: Julius Unverfehrt <julius.unverfehrt@iqser.com>
Date:   Tue Mar 15 15:02:25 2022 +0100

    Adjusted readme and config to the actual changes.

commit d00a0aadc02be8ab1343dbaa2a8df82418e77673
Merge: ded29bb c34a1ce
Author: Julius Unverfehrt <julius.unverfehrt@iqser.com>
Date:   Tue Mar 15 14:06:21 2022 +0100

    Merge branch 'master' of ssh://git.iqser.com:2222/rr/pyinfra into pika_encapsulation

commit ded29bb8d0a78c1d5fd172edb74f0b120da05d5f
Author: Julius Unverfehrt <julius.unverfehrt@iqser.com>
Date:   Tue Mar 15 14:02:37 2022 +0100

    blackkkkkkkkkkkkkkkk

commit de3899c69f4e83fa5b8dc2702a18be66661201e9
Author: Julius Unverfehrt <julius.unverfehrt@iqser.com>
Date:   Tue Mar 15 13:59:43 2022 +0100

    Judgement Day - The Great Refactoring

commit 48eb2bb792dd45539e7db20949fedae03316c2b3
Author: Julius Unverfehrt <julius.unverfehrt@iqser.com>
Date:   Tue Mar 15 13:30:41 2022 +0100

    blacky

commit 38fb073291c6989fcc5560ef9f97e7dafd9570be
Author: Julius Unverfehrt <julius.unverfehrt@iqser.com>
Date:   Tue Mar 15 13:29:46 2022 +0100

    update mock client/publish

commit 8744cdeb6d21adab81fb82c7850a4ada31a1a3f9
Author: Julius Unverfehrt <julius.unverfehrt@iqser.com>
Date:   Tue Mar 15 13:29:19 2022 +0100

    quickfix consumer.consume_and_publish bug

commit 0ce30cafbf15142e3f773d0e510b7d22927ee86d
Author: Matthias Bisping <matthias.bisping@iqser.com>
Date:   Tue Mar 15 11:42:04 2022 +0100

    updated serving logic for new queue, visitor and storage logic

commit 5ac5fbf50bf9d8a6b6466cd04d82d21525250b4c
Author: Matthias Bisping <matthias.bisping@iqser.com>
Date:   Tue Mar 15 11:41:18 2022 +0100

    changed callback signature to expect a dict with a key 'data'

commit 4c7c4b466e2b5228a6098c9130cd6c3334315153
Author: Matthias Bisping <matthias.bisping@iqser.com>
Date:   Tue Mar 15 09:46:39 2022 +0100

    set dead letter queue related params

commit f8b60aad16eb0f5c91bed857ac3b352c4bcb7468
Author: Matthias Bisping <matthias.bisping@iqser.com>
Date:   Mon Mar 14 20:22:57 2022 +0100

    added republishing logic

commit a40f07139af9121d8b72efa52b6b2588e7d233e8
Author: Matthias Bisping <matthias.bisping@iqser.com>
Date:   Mon Mar 14 19:10:08 2022 +0100

    removed obsolete import

commit d3605b0913f8b155961b8c8e90e7c1d442a1da4f
Author: Matthias Bisping <matthias.bisping@iqser.com>
Date:   Mon Mar 14 19:09:50 2022 +0100

    removed obsolete code

commit 9130b9fc753c63bfe4abd3cc26b31caf5c5bf2cb
Author: Matthias Bisping <matthias.bisping@iqser.com>
Date:   Mon Mar 14 19:05:59 2022 +0100

    test refac

commit 1205adecdbd42a46dcc264a46086aab94c285b94
Author: Matthias Bisping <matthias.bisping@iqser.com>
Date:   Mon Mar 14 14:49:59 2022 +0100

    removed obsolete import

commit 59d0de97e39646b3fa6c6ca91644dbf8adbd0de3
Author: Matthias Bisping <matthias.bisping@iqser.com>
Date:   Mon Mar 14 14:49:21 2022 +0100

    applied black

commit f78a6894bc99b42d53a7c265ffd8c7869e9d606d
Author: Matthias Bisping <matthias.bisping@iqser.com>
Date:   Mon Mar 14 14:49:03 2022 +0100

    applied black

commit 32740a914f8be15eed954389f6ddbb565e1567e6
Author: Matthias Bisping <matthias.bisping@iqser.com>
Date:   Mon Mar 14 14:47:47 2022 +0100

    removed obsolete code

commit f21cbc141efc10d0e779f683dcc452e6c889dd6b
Author: Matthias Bisping <matthias.bisping@iqser.com>
Date:   Mon Mar 14 14:47:27 2022 +0100

    consumer with visitor test fixed (added connection.close() in case of pika queue manager)

commit c415dcd8e6375e015b959a9c3cdee61b8ffe7d60
Author: Matthias Bisping <matthias.bisping@iqser.com>
Date:   Mon Mar 14 13:56:54 2022 +0100

    consumer with visitor test WIP 2

... and 15 more commits
This commit is contained in:
Julius Unverfehrt 2022-03-15 15:05:14 +01:00
parent c34a1ce849
commit f5b7203778
29 changed files with 741 additions and 523410 deletions

README.md

@ -6,154 +6,100 @@ The Infrastructure expects to be deployed in the same Pod / local environment as
A configuration is located in `/config.yaml`. All relevant variables can be configured via exporting environment variables.
| Environment Variable | Default | Description |
|-------------------------------|--------------------------------|--------------------------------------------------------------------------------------------------|
| _service_ | | |
| LOGGING_LEVEL_ROOT | DEBUG | Logging level for service logger |
| RESPONSE_TYPE | "stream" | Whether the analysis response is stored as file on storage or sent as stream: "file" or "stream" |
| RESPONSE_FILE_EXTENSION | ".NER_ENTITIES.json.gz" | Extension to the file that stores the analyzed response on storage |
| _probing_webserver_ | | |
| PROBING_WEBSERVER_HOST | "0.0.0.0" | Probe webserver address |
| PROBING_WEBSERVER_PORT | 8080 | Probe webserver port |
| PROBING_WEBSERVER_MODE | production | Webserver mode: {development, production} |
| _rabbitmq_ | | |
| RABBITMQ_HOST | localhost | RabbitMQ host address |
| RABBITMQ_PORT | 5672 | RabbitMQ host port |
| RABBITMQ_USERNAME | user | RabbitMQ username |
| RABBITMQ_PASSWORD | bitnami | RabbitMQ password |
| RABBITMQ_HEARTBEAT | 7200 | Controls AMQP heartbeat timeout in seconds |
| _queues_ | | |
| REQUEST_QUEUE | request_queue | Requests to service |
| RESPONSE_QUEUE | response_queue | Responses by service |
| DEAD_LETTER_QUEUE | dead_letter_queue | Messages that failed to process |
| _callback_ | | |
| RETRY | False | Toggles retry behaviour |
| MAX_ATTEMPTS | 3 | Number of times a message may fail before being published to dead letter queue |
| ANALYSIS_ENDPOINT | "http://127.0.0.1:5000" | |
| _storage_ | | |
| STORAGE_BACKEND | s3 | The type of storage to use {s3, azure} |
| STORAGE_BUCKET | "pyinfra-test-bucket" | The bucket / container to pull files specified in queue requests from |
| TARGET_FILE_EXTENSION | ".TEXT.json.gz" | Defines type of file to pull from storage: .TEXT.json.gz or .ORIGIN.pdf.gz |
| STORAGE_ENDPOINT | "http://127.0.0.1:9000" | |
| STORAGE_KEY | | |
| STORAGE_SECRET | | |
| STORAGE_AZURECONNECTIONSTRING | "DefaultEndpointsProtocol=..." | |
| Environment Variable | Default | Description |
|-------------------------------|--------------------------------|-----------------------------------------------------------------------|
| LOGGING_LEVEL_ROOT | DEBUG | Logging level for service logger |
| PROBING_WEBSERVER_HOST | "0.0.0.0" | Probe webserver address |
| PROBING_WEBSERVER_PORT | 8080 | Probe webserver port |
| PROBING_WEBSERVER_MODE | production | Webserver mode: {development, production} |
| RABBITMQ_HOST | localhost | RabbitMQ host address |
| RABBITMQ_PORT | 5672 | RabbitMQ host port |
| RABBITMQ_USERNAME | user | RabbitMQ username |
| RABBITMQ_PASSWORD | bitnami | RabbitMQ password |
| RABBITMQ_HEARTBEAT | 7200 | Controls AMQP heartbeat timeout in seconds |
| REQUEST_QUEUE | request_queue | Requests to service |
| RESPONSE_QUEUE | response_queue | Responses by service |
| DEAD_LETTER_QUEUE | dead_letter_queue | Messages that failed to process |
| ANALYSIS_ENDPOINT | "http://127.0.0.1:5000" | Endpoint for analysis container |
| STORAGE_BACKEND | s3 | The type of storage to use {s3, azure} |
| STORAGE_BUCKET | "pyinfra-test-bucket" | The bucket / container to pull files specified in queue requests from |
| STORAGE_ENDPOINT | "http://127.0.0.1:9000" | Endpoint for s3 storage |
| STORAGE_KEY | root | User for s3 storage |
| STORAGE_SECRET | password | Password for s3 storage |
| STORAGE_AZURECONNECTIONSTRING | "DefaultEndpointsProtocol=..." | Connection string for Azure storage |
## Response Format
### RESPONSE_AS_FILE == False
Response-Format:
### Expected AMQP input message:
```json
{
"dossierId": "klaus",
"fileId": "1a7fd8ac0da7656a487b68f89188be82",
"imageMetadata": ANALYSIS_DATA
"dossierId": "",
"fileId": "",
"targetFileExtension": "",
"responseFileExtension": ""
}
```
Response-example for image-prediction
Optionally, the input message can contain a field with the key `"operations"`.
### AMQP output message:
```json
{
"dossierId": "klaus",
"fileId": "1a7fd8ac0da7656a487b68f89188be82",
"imageMetadata": [
{
"classification": {
"label": "logo",
"probabilities": {
"formula": 0.0,
"logo": 1.0,
"other": 0.0,
"signature": 0.0
}
},
"filters": {
"allPassed": true,
"geometry": {
"imageFormat": {
"quotient": 1.570791527313267,
"tooTall": false,
"tooWide": false
},
"imageSize": {
"quotient": 0.19059804229011604,
"tooLarge": false,
"tooSmall": false
}
},
"probability": {
"unconfident": false
}
},
"geometry": {
"height": 107.63999999999999,
"width": 169.08000000000004
},
"position": {
"pageNumber": 1,
"x1": 213.12,
"x2": 382.20000000000005,
"y1": 568.7604,
"y2": 676.4004
}
}
]
"dossierId": "",
"fileId": "",
...
}
```
### RESPONSE_AS_FILE == True
Creates a response file on the request storage, named `dossier_Id / file_Id + RESPONSE_FILE_EXTENSION` with the `ANALYSIS_DATA` as content.
## Development
### Local Setup
Either run `src/serve.py` or the built Docker image.
You can run the infrastructure either as a module via `src/serve.py` or as a Docker container simulating the Kubernetes environment.
### Setup
1. Install module / build docker image
Install module.
```bash
pip install -e .
pip install -r requirements.txt
```
```bash
pip install -e .
pip install -r requirements.txt
```
```bash
docker build -f Dockerfile -t pyinfra .
```
or build the Docker image.
2. Run rabbitmq & minio
```bash
docker build -f Dockerfile -t pyinfra .
```
```bash
docker-compose up
```
### Usage
3. Run module
**Shell 1:** Start a MinIO and a RabbitMQ docker container.
```bash
python src/serve.py
```
or as a container:
```bash
docker run --net=host pyinfra
```
Start your prediction container, for example ner-prediction or image-prediction (follow the corresponding README to build the container).
```bash
docker-compose up
```
To put a file on the queue do:
**Shell 2:** Add files to the local minio storage.
python src/manage_minio.py add --file path/to/file dossierID
```bash
python scripts/manage_minio.py add <MinIO target folder> -d path/to/a/folder/with/PDFs
```
To start mock:
**Shell 2:** Run pyinfra-server.
python src/mock_client.py
```bash
python src/serve.py
```
or as a container:
### Hints:
After stopping `docker-compose up`, run `docker-compose down` to remove the containers it created.
```bash
docker run --net=host pyinfra
```
If uploaded files are stuck, clean the MinIO storage with `python src/manage_minio.py purge`, or delete the local MinIO data folder in pyinfra with `sudo rm -rf data`.
**Shell 3:** Run analysis-container.
**Shell 4:** Start a client that sends requests to process PDFs from the MinIO store and annotates these PDFs according to the service responses.
```bash
python scripts/mock_client.py
```


@ -1,9 +1,5 @@
service:
  logging_level: $LOGGING_LEVEL_ROOT|DEBUG # Logging level for service logger
  response:
    type: $RESPONSE_TYPE|"file" # Whether the analysis response is stored as file on storage or sent as stream
    extension: $RESPONSE_FILE_EXTENSION|".IMAGE_INFO.json.gz" # {.IMAGE_INFO.json.gz | .NER_ENTITIES.json.gz}
    key: $RESPONSE_KEY|"imageMetadata" # the key of the result {result, imageMetadata}
probing_webserver:
  host: $PROBING_WEBSERVER_HOST|"0.0.0.0" # Probe webserver address
@ -22,22 +18,12 @@ rabbitmq:
    output: $RESPONSE_QUEUE|response_queue # Responses by service
    dead_letter: $DEAD_LETTER_QUEUE|dead_letter_queue # Messages that failed to process
  prefetch_count: 1
  callback:
    retry: # Controls retry behaviour for messages the processing of which failed
      # TODO: check if this actually works
      enabled: $RETRY|False # Toggles retry behaviour
      max_attempts: $MAX_ATTEMPTS|3 # Number of times a message may fail before being published to dead letter queue
    analysis_endpoint: $ANALYSIS_ENDPOINT|"http://127.0.0.1:5000"
storage:
  backend: $STORAGE_BACKEND|s3 # The type of storage to use {s3, azure}
  bucket: $STORAGE_BUCKET|"pyinfra-test-bucket" # The bucket / container to pull files specified in queue requests from
  target_file_extension: $TARGET_FILE_EXTENSION|".ORIGIN.pdf.gz" # {.TEXT.json.gz | .ORIGIN.pdf.gz} Defines type of file to pull from storage
  s3:
    endpoint: $STORAGE_ENDPOINT|"http://127.0.0.1:9000"
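The config entries above use the `$VAR|default` syntax, resolved by envyaml against the environment. A minimal sketch of that override pattern (illustrative only, not the envyaml implementation):

```python
import os

# Hypothetical resolver mirroring the "$VAR|default" entries in config.yaml:
# the environment variable wins; the value after the pipe is the fallback.
def resolve(template: str, env=os.environ) -> str:
    if template.startswith("$"):
        name, _, default = template[1:].partition("|")
        return env.get(name, default)
    return template

print(resolve("$RABBITMQ_PORT|5672", env={}))                         # 5672
print(resolve("$RABBITMQ_PORT|5672", env={"RABBITMQ_PORT": "5673"}))  # 5673
```

In the real config, exporting e.g. `RABBITMQ_PORT` before starting the service overrides the default the same way.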


@ -1,96 +0,0 @@
import json
import logging
import tempfile
from time import sleep

from pyinfra.config import CONFIG
from pyinfra.exceptions import AnalysisFailure, DataLoadingFailure
from pyinfra.rabbitmq import make_connection, make_channel, declare_queue
from pyinfra.storage.storages import get_storage
from pyinfra.utils.file import upload_compressed_response


def make_retry_callback(republish, max_attempts):
    def get_n_previous_attempts(props):
        return 0 if props.headers is None else props.headers.get("x-retry-count", 0)

    def attempts_remain(n_attempts):
        return n_attempts < max_attempts

    def callback(channel, method, properties, body):
        n_attempts = get_n_previous_attempts(properties) + 1
        logging.error(f"Message failed to process {n_attempts}/{max_attempts} times: {body}")
        if attempts_remain(n_attempts):
            republish(channel, body, n_attempts)
            channel.basic_ack(delivery_tag=method.delivery_tag)
        else:
            logging.exception(f"Adding to dead letter queue: {body}")
            channel.basic_reject(delivery_tag=method.delivery_tag, requeue=False)

    return callback


def wrap_callback_in_retry_logic(callback, retry_callback):
    def wrapped_callback(channel, method, properties, body):
        try:
            callback(channel, method, properties, body)
        except (AnalysisFailure, DataLoadingFailure):
            sleep(5)
            retry_callback(channel, method, properties, body)

    return wrapped_callback


def json_wrap(body_processor):
    def inner(payload):
        return json.dumps(body_processor(json.loads(payload)))

    return inner


def make_callback_for_output_queue(json_wrapped_body_processor, output_queue_name):
    connection = make_connection()
    channel = make_channel(connection)
    declare_queue(channel, output_queue_name)

    def callback(channel, method, _, body):
        """Publish the processed response to the output queue.

        The response contains dossier_id, file_id and the analysis result. If
        CONFIG.service.response.type == "file", the analysis result is written
        to a JSON file on storage instead, and the response contains only
        dossier_id and file_id.
        """
        dossier_id, file_id, result = json_wrapped_body_processor(body)
        result_key = CONFIG.service.response.key if CONFIG.service.response.key else "result"
        result = json.dumps({"dossierId": dossier_id, "fileId": file_id, result_key: result})
        if CONFIG.service.response.type == "file":
            upload_compressed_response(
                get_storage(CONFIG.storage.backend), CONFIG.storage.bucket, dossier_id, file_id, result
            )
            result = json.dumps({"dossierId": dossier_id, "fileId": file_id})
        channel.basic_publish(exchange="", routing_key=output_queue_name, body=result)
        channel.basic_ack(delivery_tag=method.delivery_tag)

    return callback


def make_retry_callback_for_output_queue(json_wrapped_body_processor, output_queue_name, retry_callback):
    callback = make_callback_for_output_queue(json_wrapped_body_processor, output_queue_name)
    callback = wrap_callback_in_retry_logic(callback, retry_callback)
    return callback


@ -1,11 +1,23 @@
"""Implements a config object with dot-indexing syntax."""
from envyaml import EnvYAML
from pyinfra.locations import CONFIG_FILE
def make_art():
    return """
______ _____ __
| ___ \ |_ _| / _|
| |_/ / _ | | _ __ | |_ _ __ __ _
| __/ | | || || '_ \| _| '__/ _` |
| | | |_| || || | | | | | | | (_| |
\_| \__, \___/_| |_|_| |_| \__,_|
__/ |
|___/
"""
def _get_item_and_maybe_make_dotindexable(container, item):
    ret = container[item]
    return DotIndexable(ret) if isinstance(ret, dict) else ret


@ -1,43 +0,0 @@
import logging
from typing import Callable

import pika
from retry import retry

from pyinfra.exceptions import ProcessingFailure
from pyinfra.rabbitmq import make_connection, make_channel, declare_queue


class ConsumerError(Exception):
    pass


@retry(ConsumerError, tries=3, delay=5, jitter=(1, 3))
def consume(queue_name: str, on_message_callback: Callable):
    connection = make_connection()
    channel = make_channel(connection)
    declare_queue(channel, queue_name)
    logging.info("Started infrastructure.")
    while True:
        try:
            logging.info("Waiting for messages...")
            channel.basic_consume(queue=queue_name, auto_ack=False, on_message_callback=on_message_callback)
            channel.start_consuming()
        except pika.exceptions.ConnectionClosedByBroker as err:
            logging.critical(f"Connection closed by broker: {err}, stopping.")
            break
        except pika.exceptions.AMQPChannelError as err:
            logging.critical(f"Caught a channel error: {err}, stopping.")
            break
        except pika.exceptions.AMQPConnectionError as err:
            logging.info("No AMQP connection found, retrying...")
            logging.debug(err)
            continue
        except ProcessingFailure as err:
            raise ConsumerError(f"Error while consuming {queue_name}.") from err


@ -1,70 +0,0 @@
import gzip
import json
import logging
from operator import itemgetter

import requests

from pyinfra.config import CONFIG
from pyinfra.exceptions import DataLoadingFailure, AnalysisFailure, ProcessingFailure
from pyinfra.utils.file import combine_dossier_id_and_file_id_and_extension


def make_storage_data_loader(storage, bucket_name):
    def get_object_name(payload: dict) -> str:
        dossier_id, file_id = itemgetter("dossierId", "fileId")(payload)
        object_name = combine_dossier_id_and_file_id_and_extension(
            dossier_id, file_id, CONFIG.storage.target_file_extension
        )
        return object_name

    def download(payload):
        object_name = get_object_name(payload)
        logging.debug(f"Downloading {object_name}...")
        data = storage.get_object(bucket_name, object_name)
        logging.debug(f"Downloaded {object_name}.")
        return data

    def decompress(data):
        return gzip.decompress(data)

    def load_data(payload):
        try:
            return decompress(download(payload))
        except Exception as err:
            logging.warning(f"Loading data from storage failed for {payload}.")
            raise DataLoadingFailure() from err

    return load_data


def make_analyzer(analysis_endpoint):
    def analyze(data):
        try:
            logging.debug(f"Requesting analysis from {analysis_endpoint}...")
            analysis_response = requests.post(analysis_endpoint, data=data)
            analysis_response.raise_for_status()
            analysis_response = analysis_response.json()
            logging.debug("Received response.")
            return analysis_response
        except Exception as err:
            logging.warning("Exception caught when calling analysis endpoint.")
            raise AnalysisFailure() from err

    return analyze


def make_payload_processor(load_data, analyze_file):
    def process(payload: dict):
        logging.info(f"Processing {payload}...")
        try:
            payload = json.loads(payload)
            dossier_id, file_id = itemgetter("dossierId", "fileId")(payload)
            data = load_data(payload)
            predictions = analyze_file(data)
            return dossier_id, file_id, predictions
        except (DataLoadingFailure, AnalysisFailure) as err:
            logging.warning(f"Processing of {payload} failed.")
            raise ProcessingFailure() from err

    return process


@ -20,3 +20,7 @@ class InvalidEndpoint(ValueError):
class UnknownClient(ValueError):
    pass


class ConsumerError(Exception):
    pass


@ -1,11 +1,12 @@
import requests
from flask import Flask, jsonify
from retry import retry
from waitress import serve
from pyinfra.config import CONFIG
def run_probing_webserver(app, host=None, port=None, mode=None):
if not host:
host = CONFIG.probing_webserver.host
@ -23,7 +24,6 @@ def run_probing_webserver(app, host=None, port=None, mode=None):
def set_up_probing_webserver():
    app = Flask(__name__)

    @app.route("/ready", methods=["GET"])
@ -38,4 +38,15 @@ def set_up_probing_webserver():
        resp.status_code = 200
        return resp

    @app.route("/prometheus", methods=["GET"])
    def get_analysis_prometheus_endpoint():
        @retry(requests.exceptions.ConnectionError, tries=3, delay=5, jitter=(1, 3))
        def inner():
            prom_endpoint = f"{CONFIG.rabbitmq.callback.analysis_endpoint}/prometheus"
            metric = requests.get(prom_endpoint)
            metric.raise_for_status()
            return metric.text

        return inner()

    return app

pyinfra/queue/consumer.py

@ -0,0 +1,13 @@
from pyinfra.queue.queue_manager.queue_manager import QueueManager


class Consumer:
    def __init__(self, callback, queue_manager: QueueManager):
        self.queue_manager = queue_manager
        self.callback = callback

    def consume_and_publish(self):
        self.queue_manager.consume_and_publish(self.callback)

    def consume(self):
        yield from self.queue_manager.consume()
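The new `Consumer` only delegates to its `QueueManager`, so it can be exercised with any in-memory stand-in. A self-contained sketch (the `InMemoryManager` here is illustrative, not part of the PR):

```python
from collections import deque


# Hypothetical in-memory manager with the same consume/consume_and_publish
# shape as the QueueManager interface from the PR.
class InMemoryManager:
    def __init__(self):
        self.input, self.output = deque(), deque()

    def consume(self):
        while self.input:
            yield self.input.popleft()

    def consume_and_publish(self, callback):
        for message in self.consume():
            self.output.append(callback(message))


# The Consumer itself, as introduced by this PR: pure delegation.
class Consumer:
    def __init__(self, callback, queue_manager):
        self.queue_manager = queue_manager
        self.callback = callback

    def consume_and_publish(self):
        self.queue_manager.consume_and_publish(self.callback)


manager = InMemoryManager()
manager.input.extend(["a", "b"])
Consumer(str.upper, manager).consume_and_publish()
print(list(manager.output))  # ['A', 'B']
```

Because the broker interaction lives entirely in the manager, swapping `InMemoryManager` for `PikaQueueManager` changes nothing in the consumer itself.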


@ -0,0 +1,142 @@
import json
import logging

import pika

from pyinfra.config import CONFIG
from pyinfra.exceptions import ProcessingFailure
from pyinfra.queue.queue_manager.queue_manager import QueueHandle, QueueManager

logger = logging.getLogger("pika")
logger.setLevel(logging.WARNING)
logger = logging.getLogger(__name__)
logger.setLevel(CONFIG.service.logging_level)


def monkey_patch_queue_handle(channel, queue) -> QueueHandle:
    empty_message = (None, None, None)

    def is_empty_message(message):
        return message == empty_message

    queue_handle = QueueHandle()
    queue_handle.empty = lambda: is_empty_message(channel.basic_get(queue))

    def produce_items():
        while True:
            message = channel.basic_get(queue)
            if is_empty_message(message):
                break
            method_frame, properties, body = message
            channel.basic_ack(method_frame.delivery_tag)
            yield json.loads(body)

    queue_handle.to_list = lambda: list(produce_items())
    return queue_handle


def get_connection():
    credentials = pika.PlainCredentials(username=CONFIG.rabbitmq.user, password=CONFIG.rabbitmq.password)
    kwargs = {"host": CONFIG.rabbitmq.host, "port": CONFIG.rabbitmq.port, "credentials": credentials}
    parameters = pika.ConnectionParameters(**kwargs)
    connection = pika.BlockingConnection(parameters=parameters)
    return connection


def get_n_previous_attempts(props):
    return 0 if props.headers is None else props.headers.get("x-retry-count", 0)


def attempts_remain(n_attempts, max_attempts):
    return n_attempts < max_attempts


class PikaQueueManager(QueueManager):
    def __init__(self, input_queue, output_queue, dead_letter_queue=None):
        super().__init__(input_queue, output_queue)
        self.connection = get_connection()
        self.channel = self.connection.channel()
        if not dead_letter_queue:
            dead_letter_queue = CONFIG.rabbitmq.queues.dead_letter
        args = {"x-dead-letter-exchange": "", "x-dead-letter-routing-key": dead_letter_queue}
        self.channel.queue_declare(input_queue, arguments=args, auto_delete=False)
        self.channel.queue_declare(output_queue, arguments=args, auto_delete=False)

    def republish(self, body, n_current_attempts, frame):
        self.channel.basic_publish(
            exchange="",
            routing_key=self._input_queue,
            body=body,
            properties=pika.BasicProperties(headers={"x-retry-count": n_current_attempts}),
        )
        self.channel.basic_ack(delivery_tag=frame.delivery_tag)

    def publish_request(self, request):
        logger.debug(f"Publishing {request}")
        self.channel.basic_publish("", self._input_queue, json.dumps(request).encode())

    def reject(self, body, frame):
        logger.exception(f"Adding to dead letter queue: {body}")
        self.channel.basic_reject(delivery_tag=frame.delivery_tag, requeue=False)

    def publish_response(self, message, callback, max_attempts=3):
        logger.debug(f"Publishing response for {message}.")
        frame, properties, body = message
        n_attempts = get_n_previous_attempts(properties) + 1
        try:
            response = json.dumps(callback(json.loads(body)))
            self.channel.basic_publish("", self._output_queue, response.encode())
            self.channel.basic_ack(frame.delivery_tag)
        except ProcessingFailure:
            logger.error(f"Message failed to process {n_attempts}/{max_attempts} times: {body}")
            if attempts_remain(n_attempts, max_attempts):
                self.republish(body, n_attempts, frame)
            else:
                self.reject(body, frame)

    def pull_request(self):
        return self.channel.basic_get(self._input_queue)

    def consume(self):
        logger.debug("Consuming")
        return self.channel.consume(self._input_queue)

    def consume_and_publish(self, visitor):
        logger.info(f"Consuming with callback {visitor.callback.__name__}")
        for message in self.consume():
            self.publish_response(message, visitor)

    def clear(self):
        self.channel.queue_purge(self._input_queue)
        self.channel.queue_purge(self._output_queue)

    @property
    def input_queue(self) -> QueueHandle:
        return monkey_patch_queue_handle(self.channel, self._input_queue)

    @property
    def output_queue(self) -> QueueHandle:
        return monkey_patch_queue_handle(self.channel, self._output_queue)
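The retry bookkeeping in `publish_response` travels with the message in an `x-retry-count` header, defaulting to 0 on first delivery. A pure-function sketch of the decision (`decide` is an illustrative helper mirroring `get_n_previous_attempts` and `attempts_remain` above, not part of the module):

```python
# Attempt count carried in the message headers; None means first delivery.
def get_n_previous_attempts(headers) -> int:
    return 0 if headers is None else headers.get("x-retry-count", 0)


def decide(headers, max_attempts=3) -> str:
    """Return 'republish' while attempts remain, else 'reject' (dead letter)."""
    n_attempts = get_n_previous_attempts(headers) + 1
    return "republish" if n_attempts < max_attempts else "reject"


print(decide(None))                  # republish (first failure)
print(decide({"x-retry-count": 1}))  # republish (second failure)
print(decide({"x-retry-count": 2}))  # reject    (third failure hits max_attempts)
```

On rejection, the broker routes the message to the dead letter queue via the `x-dead-letter-routing-key` argument set in `queue_declare`.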


@ -0,0 +1,47 @@
import abc


class QueueHandle:
    def empty(self) -> bool:
        raise NotImplementedError()

    def to_list(self) -> list:
        raise NotImplementedError()


class QueueManager(abc.ABC):
    def __init__(self, input_queue, output_queue):
        self._input_queue = input_queue
        self._output_queue = output_queue

    @abc.abstractmethod
    def publish_request(self, request):
        pass

    @abc.abstractmethod
    def publish_response(self, response, callback):
        pass

    @abc.abstractmethod
    def pull_request(self):
        pass

    @abc.abstractmethod
    def consume(self):
        pass

    @abc.abstractmethod
    def clear(self):
        pass

    @abc.abstractmethod
    def input_queue(self) -> QueueHandle:
        pass

    @abc.abstractmethod
    def output_queue(self) -> QueueHandle:
        pass

    @abc.abstractmethod
    def consume_and_publish(self, callback):
        pass


@ -1,31 +0,0 @@
import pika

from pyinfra.config import CONFIG


def make_channel(connection) -> pika.adapters.blocking_connection.BlockingChannel:
    channel = connection.channel()
    channel.basic_qos(prefetch_count=CONFIG.rabbitmq.prefetch_count)
    return channel


def declare_queue(channel, queue: str):
    args = {"x-dead-letter-exchange": "", "x-dead-letter-routing-key": CONFIG.rabbitmq.queues.dead_letter}
    return channel.queue_declare(queue=queue, auto_delete=False, arguments=args, durable=True)


def read_connection_params():
    credentials = pika.PlainCredentials(CONFIG.rabbitmq.user, CONFIG.rabbitmq.password)
    parameters = pika.ConnectionParameters(
        host=CONFIG.rabbitmq.host,
        port=CONFIG.rabbitmq.port,
        heartbeat=CONFIG.rabbitmq.heartbeat,
        credentials=credentials,
    )
    return parameters


def make_connection() -> pika.BlockingConnection:
    parameters = read_connection_params()
    connection = pika.BlockingConnection(parameters)
    return connection


@ -0,0 +1,43 @@
from pyinfra.queue.queue_manager.queue_manager import QueueManager, QueueHandle
from pyinfra.test.queue.queue_mock import QueueMock


def monkey_patch_queue_handle(queue) -> QueueHandle:
    queue_handle = QueueHandle()
    queue_handle.empty = lambda: not queue
    queue_handle.to_list = lambda: list(queue)
    return queue_handle


class QueueManagerMock(QueueManager):
    def __init__(self, input_queue, output_queue):
        super().__init__(QueueMock(), QueueMock())

    def publish_request(self, request):
        self._input_queue.append(request)

    def publish_response(self, message, callback):
        self._output_queue.append(callback(message))

    def pull_request(self):
        return self._input_queue.popleft()

    def consume(self):
        while self._input_queue:
            yield self.pull_request()

    def consume_and_publish(self, callback):
        for message in self.consume():
            self.publish_response(message, callback)

    def clear(self):
        self._input_queue.clear()
        self._output_queue.clear()

    @property
    def input_queue(self) -> QueueHandle:
        return monkey_patch_queue_handle(self._input_queue)

    @property
    def output_queue(self) -> QueueHandle:
        return monkey_patch_queue_handle(self._output_queue)


@ -0,0 +1,5 @@
from collections import deque


class QueueMock(deque):
    pass

File diff suppressed because it is too large


@ -6,20 +6,24 @@ import pytest
from pyinfra.exceptions import UnknownClient
from pyinfra.locations import TEST_DIR
from pyinfra.queue.queue_manager.pika_queue_manager import PikaQueueManager
from pyinfra.queue.queue_manager.queue_manager import QueueManager
from pyinfra.storage.adapters.azure import AzureStorageAdapter
from pyinfra.storage.adapters.s3 import S3StorageAdapter
from pyinfra.storage.clients.azure import get_azure_client
from pyinfra.storage.clients.s3 import get_s3_client
from pyinfra.storage.storage import Storage
from pyinfra.test.config import CONFIG
from pyinfra.test.queue.queue_manager_mock import QueueManagerMock
from pyinfra.test.storage.adapter_mock import StorageAdapterMock
from pyinfra.test.storage.client_mock import StorageClientMock
from pyinfra.visitor import StorageStrategy, ForwardingStrategy, QueueVisitor

logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)


@pytest.fixture
@pytest.fixture(scope="session")
def bucket_name():
    return "pyinfra-test-bucket"
@ -51,7 +55,7 @@ def mock_make_load_data():
    return load_data


@pytest.fixture(params=["minio", "aws"])
@pytest.fixture(params=["minio", "aws"], scope="session")
def storage(client_name, bucket_name, request):
    logger.debug("Setup for storage")
    storage = Storage(get_adapter(client_name, request.param))
@ -71,3 +75,47 @@ def get_adapter(client_name, s3_backend):
        return S3StorageAdapter(get_s3_client(CONFIG.storage[s3_backend]))
    else:
        raise UnknownClient(client_name)


def get_queue_manager(queue_manager_name) -> QueueManager:
    if queue_manager_name == "mock":
        return QueueManagerMock("input", "output")
    if queue_manager_name == "pika":
        return PikaQueueManager("input", "output")


@pytest.fixture(scope="session")
def queue_manager(queue_manager_name):
    queue_manager = get_queue_manager(queue_manager_name)
    yield queue_manager
    if queue_manager_name == "pika":
        queue_manager.connection.close()


@pytest.fixture(scope="session")
def callback():
    def inner(request):
        return request["data"].decode() * 2

    return inner


@pytest.fixture
def analysis_callback(callback):
    def inner(request):
        return callback(request)

    return inner


@pytest.fixture
def response_strategy(response_strategy_name, storage):
    if response_strategy_name == "storage":
        return StorageStrategy(storage)
    if response_strategy_name == "forwarding":
        return ForwardingStrategy()


@pytest.fixture()
def visitor(storage, analysis_callback, response_strategy):
    return QueueVisitor(storage, analysis_callback, response_strategy)


@ -0,0 +1,122 @@
import gzip
import json
import logging
from operator import itemgetter

import pytest

from pyinfra.exceptions import ProcessingFailure
from pyinfra.queue.consumer import Consumer
from pyinfra.visitor import get_object_descriptor, ForwardingStrategy


@pytest.fixture(scope="session")
def consumer(queue_manager, callback):
    return Consumer(callback, queue_manager)


@pytest.fixture(scope="session")
def access_callback():
    return itemgetter("fileId")


@pytest.fixture()
def items():
    def inner():
        for i in range(3):
            body = {
                "dossierId": "folder",
                "fileId": f"file{i}",
                "targetFileExtension": "in.gz",
                "responseFileExtension": "out.gz",
            }
            yield f"{i}".encode(), body

    return list(inner())


class TestConsumer:
    @pytest.mark.parametrize("queue_manager_name", ["mock", "pika"], scope="session")
    def test_consuming_empty_input_queue_does_not_put_anything_on_output_queue(self, consumer, queue_manager):
        queue_manager.clear()
        consumer.consume()
        assert queue_manager.output_queue.empty()

    @pytest.mark.parametrize("queue_manager_name", ["mock", "pika"], scope="session")
    def test_consuming_nonempty_input_queue_puts_messages_on_output_queue_in_fifo_order(
        self, consumer, queue_manager, callback
    ):
        def produce_items():
            return map(str, range(3))

        def mock_visitor(callback):
            def inner(data):
                return callback({"data": data.encode()})

            return inner

        callback = mock_visitor(callback)
        queue_manager.clear()
        for item in produce_items():
            queue_manager.publish_request(item)
        requests = consumer.consume()
        for _, r in zip(produce_items(), requests):
            queue_manager.publish_response(r, callback)
        assert queue_manager.output_queue.to_list() == ["00", "11", "22"]

    @pytest.mark.parametrize("queue_manager_name", ["mock", "pika"], scope="session")
    @pytest.mark.parametrize("client_name", ["mock", "s3", "azure"], scope="session")
    @pytest.mark.parametrize("response_strategy_name", ["forwarding", "storage"], scope="session")
    def test_consuming_nonempty_input_queue_with_visitor_puts_messages_on_output_queue_in_fifo_order(
        self, consumer, queue_manager, visitor, bucket_name, storage, items
    ):
        visitor.response_strategy = ForwardingStrategy()
        queue_manager.clear()
        storage.clear_bucket(bucket_name)
        for data, message in items:
            storage.put_object(**get_object_descriptor(message), data=gzip.compress(data))
            queue_manager.publish_request(message)
        requests = consumer.consume()
        for _, r in zip(items, requests):
            queue_manager.publish_response(r, visitor)
        assert list(map(itemgetter("data"), queue_manager.output_queue.to_list())) == ["00", "11", "22"]

    @pytest.mark.parametrize("queue_manager_name", ["pika"], scope="session")
    def test_message_is_republished_when_callback_raises_processing_failure_exception(
        self, consumer, queue_manager, bucket_name, items
    ):
        class DebugError(Exception):
            pass

        def callback(_):
            raise ProcessingFailure()

        def reject_patch(*args, **kwargs):
            raise DebugError()

        queue_manager.reject = reject_patch
        queue_manager.clear()
        for data, message in items:
            queue_manager.publish_request(message)
        requests = consumer.consume()
        logger = logging.getLogger("pyinfra.queue.queue_manager.pika_queue_manager")
        logger.addFilter(lambda record: False)
        with pytest.raises(DebugError):
            while True:
                queue_manager.publish_response(next(requests), callback)


@@ -1,32 +0,0 @@
import json
from unittest.mock import patch

from pyinfra.core import make_analyzer, make_payload_processor
from pyinfra.test.config import CONFIG


@patch("requests.post")
def test_analyse_returns_analysis(mock_post, storage_data, mock_response):
    mock_post.return_value = mock_response
    analyze = make_analyzer(CONFIG.mock_analysis_endpoint)
    response = analyze(storage_data)
    assert response == storage_data


@patch("requests.post")
def test_process_returns_dossier_id_file_id_predictions(
    mock_post, mock_make_load_data, storage_data, mock_response, mock_payload
):
    mock_post.return_value = mock_response
    analyze = make_analyzer(CONFIG.mock_analysis_endpoint)
    mock_load_data = mock_make_load_data
    process = make_payload_processor(mock_load_data, analyze)
    dossier_id, file_id, predictions = process(mock_payload)
    assert dossier_id == json.loads(mock_payload)["dossierId"]
    assert file_id == json.loads(mock_payload)["fileId"]
    assert predictions == storage_data


@@ -0,0 +1,38 @@
import gzip
import json

import pytest

from pyinfra.visitor import get_object_descriptor, get_response_object_descriptor


@pytest.fixture()
def body():
    return {"dossierId": "folder", "fileId": "file", "targetFileExtension": "in.gz", "responseFileExtension": "out.gz"}


@pytest.mark.parametrize("client_name", ["mock", "azure", "s3"], scope="session")
class TestVisitor:
    @pytest.mark.parametrize("response_strategy_name", ["forwarding", "storage"], scope="session")
    def test_given_a_input_queue_message_callback_pulls_the_data_from_storage(
        self, visitor, body, storage, bucket_name
    ):
        storage.clear_bucket(bucket_name)
        storage.put_object(**get_object_descriptor(body), data=gzip.compress(b"content"))
        data_received = visitor.load_data(body)
        assert b"content" == data_received

    @pytest.mark.parametrize("response_strategy_name", ["forwarding", "storage"], scope="session")
    def test_visitor_pulls_and_processes_data(self, visitor, body, storage, bucket_name):
        storage.clear_bucket(bucket_name)
        storage.put_object(**get_object_descriptor(body), data=gzip.compress(b"2"))
        response_body = visitor.load_and_process(body)
        assert response_body["data"] == "22"

    @pytest.mark.parametrize("response_strategy_name", ["storage"], scope="session")
    def test_visitor_puts_response_on_storage(self, visitor, body, storage, bucket_name):
        storage.clear_bucket(bucket_name)
        storage.put_object(**get_object_descriptor(body), data=gzip.compress(b"2"))
        response_body = visitor(body)
        assert "data" not in response_body
        assert json.loads(gzip.decompress(storage.get_object(**get_response_object_descriptor(body))))["data"] == "22"


@@ -9,7 +9,7 @@ logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)


@pytest.mark.parametrize("client_name", ["mock", "azure", "s3"])
@pytest.mark.parametrize("client_name", ["mock", "azure", "s3"], scope="session")
class TestStorage:
    def test_clearing_bucket_yields_empty_bucket(self, storage, bucket_name):
        storage.clear_bucket(bucket_name)
@@ -17,26 +17,31 @@ class TestStorage:
        assert not {*data_received}

    def test_getting_object_put_in_bucket_is_object(self, storage, bucket_name):
        storage.clear_bucket(bucket_name)
        storage.put_object(bucket_name, "file", b"content")
        data_received = storage.get_object(bucket_name, "file")
        assert b"content" == data_received

    def test_getting_nested_object_put_in_bucket_is_nested_object(self, storage, bucket_name):
        storage.clear_bucket(bucket_name)
        storage.put_object(bucket_name, "folder/file", b"content")
        data_received = storage.get_object(bucket_name, "folder/file")
        assert b"content" == data_received

    def test_getting_objects_put_in_bucket_are_objects(self, storage, bucket_name):
        storage.clear_bucket(bucket_name)
        storage.put_object(bucket_name, "file1", b"content 1")
        storage.put_object(bucket_name, "folder/file2", b"content 2")
        data_received = storage.get_all_objects(bucket_name)
        assert {b"content 1", b"content 2"} == {*data_received}

    def test_make_bucket_produces_bucket(self, storage, bucket_name):
        storage.clear_bucket(bucket_name)
        storage.make_bucket(bucket_name)
        assert storage.has_bucket(bucket_name)

    def test_listing_bucket_files_yields_all_files_in_bucket(self, storage, bucket_name):
        storage.clear_bucket(bucket_name)
        storage.put_object(bucket_name, "file1", b"content 1")
        storage.put_object(bucket_name, "file2", b"content 2")
        full_names_received = storage.get_all_object_names(bucket_name)


@@ -1,16 +0,0 @@
"""Defines utilities for different operations on files."""
import gzip
import os

from pyinfra.config import CONFIG


def combine_dossier_id_and_file_id_and_extension(dossier_id, file_id, extension):
    return f"{dossier_id}/{file_id}{extension}"


def upload_compressed_response(storage, bucket_name, dossier_id, file_id, result) -> None:
    data = gzip.compress(result.encode())
    path_gz = combine_dossier_id_and_file_id_and_extension(dossier_id, file_id, CONFIG.service.response.extension)
    storage.put_object(bucket_name, path_gz, data)


@@ -1,96 +0,0 @@
import logging
import time
from functools import partial, wraps
from math import exp
from typing import Tuple, Type, Callable


class NoAttemptsLeft(Exception):
    pass


class MaxTimeoutReached(Exception):
    pass


class _MethodDecoratorAdaptor(object):
    def __init__(self, decorator, func):
        self.decorator = decorator
        self.func = func

    def __call__(self, *args, **kwargs):
        return self.decorator(self.func)(*args, **kwargs)

    def __get__(self, obj, objtype):
        return partial(self.__call__, obj)


def auto_adapt_to_methods(decorator):
    """Allows you to use the same decorator on methods and functions,
    hiding the self argument from the decorator."""

    def adapt(func):
        return _MethodDecoratorAdaptor(decorator, func)

    return adapt


def max_attempts(
    n_attempts: int = 5, exceptions: Tuple[Type[Exception], ...] = None, timeout: float = 0.1, max_timeout: float = 10
) -> Callable:
    """Function decorator that attempts to run the wrapped function a certain number of times. Timeouts increase
    exponentially according to `Tₖ = t·eᵏ`, where `t` is the timeout factor `timeout` and `k` is the attempt number.
    If `Tᵢ > mₜ` at the i-th attempt, where `mₜ` is the maximum timeout, then the function raises
    MaxTimeoutReached. If `k > mₐ`, where `mₐ` is the maximum number of attempts allowed, then the function
    raises NoAttemptsLeft.

    Args:
        n_attempts: Number of times to attempt running the wrapped function.
        exceptions: Exceptions to catch for a re-attempt.
        timeout: Timeout factor in seconds.
        max_timeout: Maximum allowed timeout.

    Raises:
        MaxTimeoutReached
        NoAttemptsLeft

    Returns:
        Wrapped function.
    """
    if not exceptions:
        exceptions = (Exception,)
    assert isinstance(exceptions, tuple)

    @auto_adapt_to_methods
    def inner(func):
        @wraps(func)
        def inner(*args, **kwargs):
            def run_attempt(attempt, timeout_aggr=0):
                if attempt:
                    try:
                        return func(*args, **kwargs)
                    except exceptions as err:
                        attempt_num = n_attempts - attempt + 1
                        next_timeout = timeout * exp(attempt_num - 1)  # start with timeout * e^0 = timeout
                        logging.warning(f"{func.__name__} failed; attempt {attempt_num} of {n_attempts}")
                        time_left = max(0, max_timeout - timeout_aggr)
                        if time_left:
                            sleep_for = min(next_timeout, time_left)
                            time.sleep(sleep_for)
                            return run_attempt(attempt - 1, timeout_aggr + sleep_for)
                        else:
                            logging.exception(err)
                            raise MaxTimeoutReached(
                                f"{func.__name__} reached maximum timeout ({max_timeout}) after {attempt_num} attempts."
                            )
                else:
                    raise NoAttemptsLeft(f"{func.__name__} failed {n_attempts} times; all attempts expended.")

            return run_attempt(n_attempts)

        return inner

    return inner
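The exponential schedule described in the docstring can be checked in isolation. The helper below is a hypothetical sketch (not part of the module) that reproduces `run_attempt`'s bookkeeping: sleep times `Tₖ = t·eᵏ`, each capped by whatever remains of the `max_timeout` budget.

```python
from math import exp


def backoff_schedule(timeout=0.1, n_attempts=5, max_timeout=10):
    """Sleep times T_k = timeout * e^k for k = 0..n_attempts-1, each capped
    by the time remaining in the max_timeout budget."""
    schedule, slept = [], 0.0
    for k in range(n_attempts):
        time_left = max(0.0, max_timeout - slept)
        if not time_left:
            break  # run_attempt would raise MaxTimeoutReached here
        sleep_for = min(timeout * exp(k), time_left)
        schedule.append(round(sleep_for, 4))
        slept += sleep_for
    return schedule


schedule = backoff_schedule()
```

With the defaults, the first sleep is exactly `timeout` (0.1 s) and each subsequent sleep grows by a factor of e, so five attempts stay within the 10-second budget.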

pyinfra/visitor.py (new file)

@@ -0,0 +1,88 @@
import abc
import gzip
import json
import logging
from operator import itemgetter
from typing import Callable

from pyinfra.config import CONFIG
from pyinfra.exceptions import DataLoadingFailure
from pyinfra.storage.storage import Storage


def get_object_name(body):
    dossier_id, file_id, target_file_extension = itemgetter("dossierId", "fileId", "targetFileExtension")(body)
    return f"{dossier_id}/{file_id}.{target_file_extension}"


def get_response_object_name(body):
    dossier_id, file_id, response_file_extension = itemgetter("dossierId", "fileId", "responseFileExtension")(body)
    return f"{dossier_id}/{file_id}.{response_file_extension}"


def get_object_descriptor(body):
    return {"bucket_name": CONFIG.storage.bucket, "object_name": get_object_name(body)}


def get_response_object_descriptor(body):
    return {"bucket_name": CONFIG.storage.bucket, "object_name": get_response_object_name(body)}


class ResponseStrategy(abc.ABC):
    @abc.abstractmethod
    def handle_response(self, body):
        pass

    def __call__(self, body):
        return self.handle_response(body)


class StorageStrategy(ResponseStrategy):
    def __init__(self, storage):
        self.storage = storage

    def handle_response(self, body):
        self.storage.put_object(**get_response_object_descriptor(body), data=gzip.compress(json.dumps(body).encode()))
        body.pop("data")
        return body


class ForwardingStrategy(ResponseStrategy):
    def handle_response(self, body):
        return body


class QueueVisitor:
    def __init__(self, storage: Storage, callback: Callable, response_strategy):
        self.storage = storage
        self.callback = callback
        self.response_strategy = response_strategy

    def load_data(self, body):
        def download():
            logging.debug(f"Downloading {object_descriptor}...")
            data = self.storage.get_object(**object_descriptor)
            logging.debug(f"Downloaded {object_descriptor}.")
            return data

        object_descriptor = get_object_descriptor(body)
        try:
            return gzip.decompress(download())
        except Exception as err:
            logging.warning(f"Loading data from storage failed for {object_descriptor}.")
            raise DataLoadingFailure() from err

    def process_data(self, data, body):
        return self.callback({**body, "data": data})

    def load_and_process(self, body):
        data = self.process_data(self.load_data(body), body)
        return {**body, "data": data}

    def __call__(self, body):
        result_body = self.load_and_process(body)
        return self.response_strategy(result_body)
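For orientation, the load → process → respond pipeline that `QueueVisitor` implements can be sketched end to end. The in-memory storage class, bucket name, and doubling callback below are illustrative assumptions (the tests use a similar doubling callback), not part of the module:

```python
import gzip


class InMemoryStorage:
    """Hypothetical in-memory stand-in for pyinfra's Storage backend."""

    def __init__(self):
        self.objects = {}

    def put_object(self, bucket_name, object_name, data):
        self.objects[(bucket_name, object_name)] = data

    def get_object(self, bucket_name, object_name):
        return self.objects[(bucket_name, object_name)]


def object_name(body):
    # Mirrors get_object_name: "<dossierId>/<fileId>.<targetFileExtension>"
    return f"{body['dossierId']}/{body['fileId']}.{body['targetFileExtension']}"


storage = InMemoryStorage()
body = {"dossierId": "folder", "fileId": "file", "targetFileExtension": "in.gz"}
storage.put_object("bucket", object_name(body), gzip.compress(b"2"))

# Load step: fetch the object, then gunzip it.
data = gzip.decompress(storage.get_object("bucket", object_name(body)))

# Process step with a doubling callback, then a forwarding-style response:
# the enriched body is returned unchanged.
result = {**body, "data": (data * 2).decode()}
```

A `StorageStrategy` would instead gzip the JSON-serialized result body back to storage and pop `"data"` before returning, which is what the `storage` response-strategy tests assert.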


@@ -7,7 +7,6 @@ from tqdm import tqdm
from pyinfra.config import CONFIG
from pyinfra.storage.storages import get_s3_storage
from pyinfra.utils.file import combine_dossier_id_and_file_id_and_extension


def parse_args():
@@ -29,9 +28,7 @@ def parse_args():
def add_file_compressed(storage, bucket_name, dossier_id, path) -> None:
    path_gz = combine_dossier_id_and_file_id_and_extension(
        dossier_id, Path(path).stem, CONFIG.storage.target_file_extension
    )
    path_gz = combine_dossier_id_and_file_id_and_extension(dossier_id, Path(path).stem, ".ORIGIN.pdf.gz")
    with open(path, "rb") as f:
        data = gzip.compress(f.read())
@@ -60,3 +57,13 @@ if __name__ == "__main__":
    elif args.command == "purge":
        storage.clear_bucket(bucket_name)


def combine_dossier_id_and_file_id_and_extension(dossier_id, file_id, extension):
    return f"{dossier_id}/{file_id}{extension}"


def upload_compressed_response(storage, bucket_name, dossier_id, file_id, result) -> None:
    data = gzip.compress(result.encode())
    path_gz = combine_dossier_id_and_file_id_and_extension(dossier_id, file_id, CONFIG.service.response.extension)
    storage.put_object(bucket_name, path_gz, data)


@@ -1,7 +1,8 @@
import json

import pika

from pyinfra.config import CONFIG
from pyinfra.rabbitmq import make_channel, declare_queue, make_connection
from pyinfra.storage.storages import get_s3_storage
@@ -10,7 +11,31 @@ def build_message_bodies():
    for bucket_name, pdf_name in storage.get_all_object_names(CONFIG.storage.bucket):
        file_id = pdf_name.split(".")[0]
        dossier_id, file_id = file_id.split("/")
        yield json.dumps({"dossierId": dossier_id, "fileId": file_id}).encode()
        yield json.dumps(
            {
                "dossierId": dossier_id,
                "fileId": file_id,
                "targetFileExtension": "ORIGIN.pdf.gz",
                "responseFileExtension": "detr.json.gz",
            }
        ).encode()


def make_channel(connection) -> pika.adapters.blocking_connection.BlockingChannel:
    channel = connection.channel()
    channel.basic_qos(prefetch_count=CONFIG.rabbitmq.prefetch_count)
    return channel


def declare_queue(channel, queue: str):
    args = {"x-dead-letter-exchange": "", "x-dead-letter-routing-key": CONFIG.rabbitmq.queues.dead_letter}
    return channel.queue_declare(queue=queue, auto_delete=False, arguments=args)


def make_connection() -> pika.BlockingConnection:
    parameters = read_connection_params()
    connection = pika.BlockingConnection(parameters)
    return connection


if __name__ == "__main__":
@@ -27,3 +52,14 @@ if __name__ == "__main__":
    for method_frame, _, body in channel.consume(queue=CONFIG.rabbitmq.queues.output):
        print(f"Received {json.loads(body)}")
        channel.basic_ack(method_frame.delivery_tag)


def read_connection_params():
    credentials = pika.PlainCredentials(CONFIG.rabbitmq.user, CONFIG.rabbitmq.password)
    parameters = pika.ConnectionParameters(
        host=CONFIG.rabbitmq.host,
        port=CONFIG.rabbitmq.port,
        heartbeat=CONFIG.rabbitmq.heartbeat,
        credentials=credentials,
    )
    return parameters


@@ -1,60 +1,65 @@
import logging
from multiprocessing import Process

import pika
import requests
from retry import retry

from pyinfra.callback import (
    make_retry_callback_for_output_queue,
    make_retry_callback,
    make_callback_for_output_queue,
)
from pyinfra.config import CONFIG
from pyinfra.consume import consume, ConsumerError
from pyinfra.core import make_payload_processor, make_storage_data_loader, make_analyzer
from pyinfra.config import CONFIG, make_art
from pyinfra.exceptions import AnalysisFailure, ConsumerError
from pyinfra.flask import run_probing_webserver, set_up_probing_webserver
from pyinfra.queue.consumer import Consumer
from pyinfra.queue.queue_manager.pika_queue_manager import PikaQueueManager
from pyinfra.storage.storages import get_storage
from pyinfra.visitor import QueueVisitor, StorageStrategy


def republish(channel, body, n_current_attempts):
    channel.basic_publish(
        exchange="",
        routing_key=CONFIG.rabbitmq.queues.input,
        body=body,
        properties=pika.BasicProperties(headers={"x-retry-count": n_current_attempts}),
    )


def make_callback(analysis_endpoint):
    def callback(message):
        def perform_operation(operation):
            endpoint = f"{analysis_endpoint}/{operation}"
            try:
                logging.debug(f"Requesting analysis from {endpoint}...")
                analysis_response = requests.post(endpoint, data=message["data"])
                analysis_response.raise_for_status()
                analysis_response = analysis_response.json()
                logging.debug("Received response.")
                return analysis_response
            except Exception as err:
                logging.warning(f"Exception caught when calling analysis endpoint {endpoint}.")
                raise AnalysisFailure() from err


def make_callback():
    load_data = make_storage_data_loader(get_storage(CONFIG.storage.backend), CONFIG.storage.bucket)
    analyze_file = make_analyzer(CONFIG.rabbitmq.callback.analysis_endpoint)
    json_wrapped_body_processor = make_payload_processor(load_data, analyze_file)
    if CONFIG.rabbitmq.callback.retry.enabled:
        retry_callback = make_retry_callback(republish, max_attempts=CONFIG.rabbitmq.callback.retry.max_attempts)
        callback = make_retry_callback_for_output_queue(
            json_wrapped_body_processor=json_wrapped_body_processor,
            output_queue_name=CONFIG.rabbitmq.queues.output,
            retry_callback=retry_callback,
        )
    else:
        callback = make_callback_for_output_queue(
            json_wrapped_body_processor=json_wrapped_body_processor, output_queue_name=CONFIG.rabbitmq.queues.output
        )
        operations = message.get("operations", ["/"])
        results = map(perform_operation, operations)
        result = dict(zip(operations, results))
        return result

    return callback


def main():
    # TODO: implement meaningful checks
    logging.info(make_art())
    webserver = Process(target=run_probing_webserver, args=(set_up_probing_webserver(),))
    logging.info("Starting webserver...")
    webserver.start()
    callback = make_callback(CONFIG.rabbitmq.callback.analysis_endpoint)
    storage = get_storage(CONFIG.storage.backend)
    response_strategy = StorageStrategy(storage)
    visitor = QueueVisitor(storage, callback, response_strategy)
    queue_manager = PikaQueueManager(CONFIG.rabbitmq.queues.input, CONFIG.rabbitmq.queues.output)

    @retry(ConsumerError, tries=3, delay=5, jitter=(1, 3))
    def consume():
        try:
            consumer = Consumer(visitor, queue_manager)
            consumer.consume_and_publish()
        except Exception as err:
            raise ConsumerError from err

    try:
        consume(CONFIG.rabbitmq.queues.input, make_callback())
        consume()
    except KeyboardInterrupt:
        pass
    except ConsumerError: