WIP: add callback factory and update example scripts

Julius Unverfehrt 2024-01-18 17:10:04 +01:00
parent 6802bf5960
commit b7f860f36b
13 changed files with 102 additions and 344 deletions

View File

@ -7,7 +7,7 @@ from funcy import identity
from prometheus_client import generate_latest, CollectorRegistry, REGISTRY, Summary
from starlette.responses import Response
from pyinfra.config.validation import validate_settings, prometheus_validators
from pyinfra.config.validation import prometheus_validators, validate_settings
def add_prometheus_endpoint(app: FastAPI, registry: CollectorRegistry = REGISTRY) -> FastAPI:

View File

@ -1,199 +0,0 @@
from dataclasses import dataclass
from functools import singledispatch, partial
from funcy import project, complement
from itertools import chain
from operator import itemgetter
from typing import Union, Sized, Callable, List
from pyinfra.config import Config
from pyinfra.utils.file_extension_parsing import make_file_extension_parser
@dataclass
class QueueMessagePayload:
"""Default one-to-one payload, where the message contains the absolute file paths for the target and response files,
that have to be acquired from the storage."""
target_file_path: str
response_file_path: str
target_file_type: Union[str, None]
target_compression_type: Union[str, None]
response_file_type: Union[str, None]
response_compression_type: Union[str, None]
x_tenant_id: Union[str, None]
processing_kwargs: dict
@dataclass
class LegacyQueueMessagePayload(QueueMessagePayload):
"""Legacy one-to-one payload, where the message contains the dossier and file ids, and the file extensions that have
to be used to construct the absolute file paths for the target and response files, that have to be acquired from the
storage."""
dossier_id: str
file_id: str
target_file_extension: str
response_file_extension: str
class QueueMessagePayloadParser:
def __init__(self, payload_matcher2parse_strategy: dict):
self.payload_matcher2parse_strategy = payload_matcher2parse_strategy
def __call__(self, payload: dict) -> QueueMessagePayload:
for payload_matcher, parse_strategy in self.payload_matcher2parse_strategy.items():
if payload_matcher(payload):
return parse_strategy(payload)
def get_queue_message_payload_parser(config: Config) -> QueueMessagePayloadParser:
file_extension_parser = make_file_extension_parser(config.allowed_file_types, config.allowed_compression_types)
payload_matcher2parse_strategy = get_payload_matcher2parse_strategy(
file_extension_parser, config.allowed_processing_parameters
)
return QueueMessagePayloadParser(payload_matcher2parse_strategy)
def get_payload_matcher2parse_strategy(parse_file_extensions: Callable, allowed_processing_parameters: List[str]):
return {
is_legacy_payload: partial(
parse_legacy_queue_message_payload,
parse_file_extensions=parse_file_extensions,
allowed_processing_parameters=allowed_processing_parameters,
),
complement(is_legacy_payload): partial(
parse_queue_message_payload,
parse_file_extensions=parse_file_extensions,
allowed_processing_parameters=allowed_processing_parameters,
),
}
def is_legacy_payload(payload: dict) -> bool:
return {"dossierId", "fileId", "targetFileExtension", "responseFileExtension"}.issubset(payload.keys())
def parse_queue_message_payload(
payload: dict,
parse_file_extensions: Callable,
allowed_processing_parameters: List[str],
) -> QueueMessagePayload:
target_file_path, response_file_path = itemgetter("targetFilePath", "responseFilePath")(payload)
target_file_type, target_compression_type, response_file_type, response_compression_type = chain.from_iterable(
map(parse_file_extensions, [target_file_path, response_file_path])
)
x_tenant_id = payload.get("X-TENANT-ID")
processing_kwargs = project(payload, allowed_processing_parameters)
return QueueMessagePayload(
target_file_path=target_file_path,
response_file_path=response_file_path,
target_file_type=target_file_type,
target_compression_type=target_compression_type,
response_file_type=response_file_type,
response_compression_type=response_compression_type,
x_tenant_id=x_tenant_id,
processing_kwargs=processing_kwargs,
)
def parse_legacy_queue_message_payload(
payload: dict,
parse_file_extensions: Callable,
allowed_processing_parameters: List[str],
) -> LegacyQueueMessagePayload:
dossier_id, file_id, target_file_extension, response_file_extension = itemgetter(
"dossierId", "fileId", "targetFileExtension", "responseFileExtension"
)(payload)
target_file_path = f"{dossier_id}/{file_id}.{target_file_extension}"
response_file_path = f"{dossier_id}/{file_id}.{response_file_extension}"
target_file_type, target_compression_type, response_file_type, response_compression_type = chain.from_iterable(
map(parse_file_extensions, [target_file_extension, response_file_extension])
)
x_tenant_id = payload.get("X-TENANT-ID")
processing_kwargs = project(payload, allowed_processing_parameters)
return LegacyQueueMessagePayload(
dossier_id=dossier_id,
file_id=file_id,
x_tenant_id=x_tenant_id,
target_file_extension=target_file_extension,
response_file_extension=response_file_extension,
target_file_type=target_file_type,
target_compression_type=target_compression_type,
response_file_type=response_file_type,
response_compression_type=response_compression_type,
target_file_path=target_file_path,
response_file_path=response_file_path,
processing_kwargs=processing_kwargs,
)
@singledispatch
def format_service_processing_result_for_storage(payload: QueueMessagePayload, result: Sized) -> dict:
raise NotImplementedError("Unsupported payload type")
@format_service_processing_result_for_storage.register(LegacyQueueMessagePayload)
def _(payload: LegacyQueueMessagePayload, result: Sized) -> dict:
processing_kwargs = payload.processing_kwargs or {}
x_tenant_id = {"X-TENANT-ID": payload.x_tenant_id} if payload.x_tenant_id else {}
return {
"dossierId": payload.dossier_id,
"fileId": payload.file_id,
"targetFileExtension": payload.target_file_extension,
"responseFileExtension": payload.response_file_extension,
**x_tenant_id,
**processing_kwargs,
"data": result,
}
@format_service_processing_result_for_storage.register(QueueMessagePayload)
def _(payload: QueueMessagePayload, result: Sized) -> dict:
processing_kwargs = payload.processing_kwargs or {}
x_tenant_id = {"X-TENANT-ID": payload.x_tenant_id} if payload.x_tenant_id else {}
return {
"targetFilePath": payload.target_file_path,
"responseFilePath": payload.response_file_path,
**x_tenant_id,
**processing_kwargs,
"data": result,
}
@singledispatch
def format_to_queue_message_response_body(queue_message_payload: QueueMessagePayload) -> dict:
raise NotImplementedError("Unsupported payload type")
@format_to_queue_message_response_body.register(LegacyQueueMessagePayload)
def _(payload: LegacyQueueMessagePayload) -> dict:
processing_kwargs = payload.processing_kwargs or {}
x_tenant_id = {"X-TENANT-ID": payload.x_tenant_id} if payload.x_tenant_id else {}
return {"dossierId": payload.dossier_id, "fileId": payload.file_id, **x_tenant_id, **processing_kwargs}
@format_to_queue_message_response_body.register(QueueMessagePayload)
def _(payload: QueueMessagePayload) -> dict:
processing_kwargs = payload.processing_kwargs or {}
x_tenant_id = {"X-TENANT-ID": payload.x_tenant_id} if payload.x_tenant_id else {}
return {
"targetFilePath": payload.target_file_path,
"responseFilePath": payload.response_file_path,
**x_tenant_id,
**processing_kwargs,
}
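
For reference, the removed parser chose its parse strategy from the keys present in the message. A minimal sketch of the two payload shapes it distinguished; all values are illustrative, the response path and tenant id are made up:

# Illustrative payloads only; values are invented for this sketch.
legacy_payload = {
    "dossierId": "dossier",
    "fileId": "file",
    "targetFileExtension": "json.gz",
    "responseFileExtension": "json.gz",
    "X-TENANT-ID": "redaction",  # optional
}
new_payload = {
    "targetFilePath": "dossier/file.json.gz",
    "responseFilePath": "dossier/file.result.json.gz",  # hypothetical response path
    "X-TENANT-ID": "redaction",  # optional
}

# is_legacy_payload() checks for the legacy key set, so the parser dispatched
# legacy_payload to parse_legacy_queue_message_payload and
# new_payload to parse_queue_message_payload.
assert {"dossierId", "fileId", "targetFileExtension", "responseFileExtension"}.issubset(legacy_payload)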

View File

@ -1,97 +0,0 @@
from kn_utils.logging import logger
from dataclasses import asdict
from typing import Callable, List
from pyinfra.config import get_config, Config
from pyinfra.payload_processing.monitor import get_monitor_from_config
from pyinfra.payload_processing.payload import (
QueueMessagePayloadParser,
get_queue_message_payload_parser,
format_service_processing_result_for_storage,
format_to_queue_message_response_body,
QueueMessagePayload,
)
from pyinfra.storage.storage import make_downloader, make_uploader
from pyinfra.storage.storage_provider import StorageProvider
class PayloadProcessor:
def __init__(
self,
storage_provider: StorageProvider,
payload_parser: QueueMessagePayloadParser,
data_processor: Callable,
):
"""Wraps an analysis function specified by a service (e.g. NER service) in pre- and post-processing steps.
Args:
storage_provider: Storage manager that connects to the storage, using the tenant id if provided
payload_parser: Parser that translates the queue message payload to the required QueueMessagePayload object
data_processor: The analysis function to be called with the downloaded file
NOTE: The result of the analysis function has to be an instance of `Sized`, e.g. a dict or a list, so that it
can be uploaded and the processing time can be monitored.
"""
self.parse_payload = payload_parser
self.provide_storage = storage_provider
self.process_data = data_processor
def __call__(self, queue_message_payload: dict) -> dict:
"""Processes a queue message payload.
The steps executed are:
1. Download the file specified in the message payload from the storage
2. Process the file with the analysis function
3. Upload the result to the storage
4. Return the payload for a response queue message
Args:
queue_message_payload: The payload of a queue message. The payload is expected to be a dict with the
following keys:
targetFilePath, responseFilePath
OR
dossierId, fileId, targetFileExtension, responseFileExtension
Returns:
The payload for a response queue message, containing only the request payload.
"""
return self._process(queue_message_payload)
def _process(self, queue_message_payload: dict) -> dict:
payload: QueueMessagePayload = self.parse_payload(queue_message_payload)
logger.info(f"Processing {payload.__class__.__name__} ...")
logger.debug(f"Payload contents: {asdict(payload)} ...")
storage, storage_info = self.provide_storage(payload.x_tenant_id)
download_file_to_process = make_downloader(
storage, storage_info.bucket_name, payload.target_file_type, payload.target_compression_type
)
upload_processing_result = make_uploader(
storage, storage_info.bucket_name, payload.response_file_type, payload.response_compression_type
)
data = download_file_to_process(payload.target_file_path)
result: List[dict] = self.process_data(data, **payload.processing_kwargs)
formatted_result = format_service_processing_result_for_storage(payload, result)
upload_processing_result(payload.response_file_path, formatted_result)
return format_to_queue_message_response_body(payload)
def make_payload_processor(data_processor: Callable, config: Config = None) -> PayloadProcessor:
"""Creates a payload processor."""
config = config or get_config()
storage_provider = StorageProvider(config)
monitor = get_monitor_from_config(config)
payload_parser: QueueMessagePayloadParser = get_queue_message_payload_parser(config)
data_processor = monitor(data_processor)
return PayloadProcessor(
storage_provider,
payload_parser,
data_processor,
)
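
For contrast with the new callback factory introduced in this commit, a minimal sketch of how the removed make_payload_processor was wired; the analysis function and file paths are hypothetical:

# Sketch of the removed API; extract_entities is a hypothetical analysis function.
def extract_entities(data: dict, **processing_kwargs) -> list:
    # Must return a Sized result (e.g. a list of dicts) so it can be uploaded.
    return [{"entity": "example"}]

processor = make_payload_processor(extract_entities)  # falls back to get_config() when no config is passed
response_body = processor(
    {"targetFilePath": "dossier/file.json.gz", "responseFilePath": "dossier/file.result.json.gz"}
)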

pyinfra/queue/callback.py Normal file (36 lines)
View File

@ -0,0 +1,36 @@
from typing import Callable, Union
from dynaconf import Dynaconf
from kn_utils.logging import logger
from pyinfra.storage.connection import get_storage
from pyinfra.storage.utils import download_data_as_specified_in_message, upload_data_as_specified_in_message
DataProcessor = Callable[[Union[dict, bytes], dict], dict]
def make_payload_processor(data_processor: DataProcessor, settings: Dynaconf):
"""Default callback for processing queue messages.
Data is downloaded from the storage as specified in the message. If a tenant id is specified, the storage is
configured to use that tenant id; otherwise it is configured as specified in the settings.
The data is then passed to the data processor together with the message. The data processor should return a
json-dump-able object, which is then uploaded to the storage as specified in the message.
The response message is simply the original message.
Adapt as needed.
"""
def inner(queue_message_payload: dict) -> dict:
logger.info(f"Processing payload...")
storage = get_storage(settings, queue_message_payload.get("X-TENANT-ID"))
data = download_data_as_specified_in_message(storage, queue_message_payload)
result = data_processor(data, queue_message_payload)
upload_data_as_specified_in_message(storage, queue_message_payload, result)
return queue_message_payload
return inner
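
A usage sketch for the new factory, mirroring the updated example consumer script later in this commit; uppercase_processor is a hypothetical stand-in for a real service function:

# Minimal wiring sketch, assuming settings are loaded as in the example scripts.
from pyinfra.config.loader import load_settings
from pyinfra.queue.callback import make_payload_processor
from pyinfra.queue.manager import QueueManager

settings = load_settings()

def uppercase_processor(data, message):
    # data is whatever was downloaded for the message; return a json-dump-able object
    return {"processed": str(data).upper(), "messageKeys": sorted(message)}

queue_manager = QueueManager(settings)
queue_manager.start_consuming(make_payload_processor(uppercase_processor, settings))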

View File

@ -13,7 +13,7 @@ from kn_utils.logging import logger
from pika.adapters.blocking_connection import BlockingChannel, BlockingConnection
from retry import retry
from pyinfra.config.validation import validate_settings, queue_manager_validators
from pyinfra.config.validation import queue_manager_validators, validate_settings
pika_logger = logging.getLogger("pika")
pika_logger.setLevel(logging.WARNING) # disables non-informative pika log clutter
@ -175,6 +175,7 @@ class QueueManager:
result,
properties=pika.BasicProperties(headers=filtered_message_headers),
)
# FIXME: publish doesn't work in example script, explore, adapt, overcome
logger.info(f"Published result to queue {self.output_queue}.")
channel.basic_ack(delivery_tag=method.delivery_tag)

View File

@ -8,7 +8,7 @@ from pyinfra.storage.storages.azure import get_azure_storage_from_settings
from pyinfra.storage.storages.s3 import get_s3_storage_from_settings
from pyinfra.storage.storages.storage import Storage
from pyinfra.utils.cipher import decrypt
from pyinfra.config.validation import validate_settings, storage_validators, multi_tenant_storage_validators
from pyinfra.config.validation import storage_validators, multi_tenant_storage_validators, validate_settings
def get_storage(settings: Dynaconf, tenant_id: str = None) -> Storage:

View File

@ -8,7 +8,7 @@ from minio import Minio
from retry import retry
from pyinfra.storage.storages.storage import Storage
from pyinfra.config.validation import validate_settings, s3_storage_validators
from pyinfra.config.validation import s3_storage_validators, validate_settings
from pyinfra.utils.url_parsing import validate_and_parse_s3_endpoint

View File

@ -5,7 +5,7 @@ import uvicorn
from dynaconf import Dynaconf
from fastapi import FastAPI
from pyinfra.config.validation import validate_settings, webserver_validators
from pyinfra.config.validation import webserver_validators, validate_settings
def create_webserver_thread_from_settings(app: FastAPI, settings: Dynaconf) -> threading.Thread:

View File

@ -1,22 +1,20 @@
import gzip
import json
import logging
from operator import itemgetter
import pika
from kn_utils.logging import logger
from pyinfra.config import get_config
from pyinfra.queue.development_queue_manager import DevelopmentQueueManager
from pyinfra.storage.storages.s3 import get_s3_storage_from_config
from pyinfra.config.loader import load_settings
from pyinfra.queue.manager import QueueManager
from pyinfra.storage.storages.s3 import get_s3_storage_from_settings
CONFIG = get_config()
logging.basicConfig()
logger = logging.getLogger()
logger.setLevel(logging.INFO)
settings = load_settings()
def upload_json_and_make_message_body():
bucket = CONFIG.storage_bucket
bucket = settings.storage.s3.bucket
dossier_id, file_id, suffix = "dossier", "file", "json.gz"
content = {
"numberOfPages": 7,
@ -26,10 +24,10 @@ def upload_json_and_make_message_body():
object_name = f"{dossier_id}/{file_id}.{suffix}"
data = gzip.compress(json.dumps(content).encode("utf-8"))
storage = get_s3_storage_from_config(CONFIG)
if not storage.has_bucket(bucket):
storage.make_bucket(bucket)
storage.put_object(bucket, object_name, data)
storage = get_s3_storage_from_settings(settings)
if not storage.has_bucket():
storage.make_bucket()
storage.put_object(object_name, data)
message_body = {
"dossierId": dossier_id,
@ -41,31 +39,31 @@ def upload_json_and_make_message_body():
def main():
development_queue_manager = DevelopmentQueueManager(CONFIG)
development_queue_manager.clear_queues()
queue_manager = QueueManager(settings)
queue_manager.purge_queues()
message = upload_json_and_make_message_body()
development_queue_manager.publish_request(message, pika.BasicProperties(headers={"X-TENANT-ID": "redaction"}))
logger.info(f"Put {message} on {CONFIG.request_queue}")
queue_manager.publish_message_to_input_queue(message)
logger.info(f"Put {message} on {settings.rabbitmq.input_queue}.")
storage = get_s3_storage_from_config(CONFIG)
for method_frame, properties, body in development_queue_manager._channel.consume(
queue=CONFIG.response_queue, inactivity_timeout=15
storage = get_s3_storage_from_settings(settings)
for method_frame, properties, body in queue_manager.channel.consume(
queue=settings.rabbitmq.output_queue, inactivity_timeout=15
):
if not body:
break
response = json.loads(body)
logger.info(f"Received {response}")
logger.info(f"Message headers: {properties.headers}")
development_queue_manager._channel.basic_ack(method_frame.delivery_tag)
queue_manager.channel.basic_ack(method_frame.delivery_tag)
dossier_id, file_id = itemgetter("dossierId", "fileId")(response)
suffix = message["responseFileExtension"]
print(f"{dossier_id}/{file_id}.{suffix}")
result = storage.get_object(CONFIG.storage_bucket, f"{dossier_id}/{file_id}.{suffix}")
result = storage.get_object(f"{dossier_id}/{file_id}.{suffix}")
result = json.loads(gzip.decompress(result))
logger.info(f"Contents of result on storage: {result}")
development_queue_manager.close_channel()
queue_manager.stop_consuming()
if __name__ == "__main__":

View File

@ -1,24 +1,42 @@
import logging
import time
from pyinfra.config import get_config
from pyinfra.payload_processing.processor import make_payload_processor
from pyinfra.queue.queue_manager import QueueManager
from fastapi import FastAPI
from pyinfra.config.loader import load_settings
from pyinfra.monitor.prometheus import make_prometheus_processing_time_decorator_from_settings, add_prometheus_endpoint
from pyinfra.queue.callback import make_payload_processor
from pyinfra.queue.manager import QueueManager
from pyinfra.webserver.utils import create_webserver_thread_from_settings
logging.basicConfig()
logger = logging.getLogger()
logger.setLevel(logging.INFO)
settings = load_settings()
def json_processor_mock(data: dict):
@make_prometheus_processing_time_decorator_from_settings(settings)
def json_processor_mock(_data: dict, _message: dict) -> dict:
time.sleep(5)
return [{"result1": "result1"}, {"result2": "result2"}]
return {"result1": "result1"}
def main():
logger.info("Start consuming...")
queue_manager = QueueManager(get_config())
queue_manager.start_consuming(make_payload_processor(json_processor_mock))
app = FastAPI()
app = add_prometheus_endpoint(app)
queue_manager = QueueManager(settings)
@app.get("/ready")
@app.get("/health")
def check_health():
return queue_manager.is_ready()
webserver_thread = create_webserver_thread_from_settings(app, settings)
webserver_thread.start()
callback = make_payload_processor(json_processor_mock, settings)
queue_manager.start_consuming(callback)
if __name__ == "__main__":

View File

@ -1,8 +1,20 @@
import pytest
from pyinfra.config.loader import load_settings
from pyinfra.storage.connection import get_storage_from_settings
@pytest.fixture(scope="session")
def settings():
return load_settings()
@pytest.fixture(scope="class")
def storage(storage_backend, settings):
settings.storage.backend = storage_backend
storage = get_storage_from_settings(settings)
storage.make_bucket()
yield storage
storage.clear_bucket()

View File

@ -5,23 +5,12 @@ from time import sleep
import pytest
from fastapi import FastAPI
from pyinfra.storage.connection import get_storage_from_settings, get_storage_from_tenant_id
from pyinfra.storage.connection import get_storage_from_tenant_id
from pyinfra.storage.utils import download_data_as_specified_in_message, upload_data_as_specified_in_message
from pyinfra.utils.cipher import encrypt
from pyinfra.webserver.utils import create_webserver_thread
@pytest.fixture(scope="class")
def storage(storage_backend, settings):
settings.storage.backend = storage_backend
storage = get_storage_from_settings(settings)
storage.make_bucket()
yield storage
storage.clear_bucket()
@pytest.mark.parametrize("storage_backend", ["azure", "s3"], scope="class")
class TestStorage:
def test_clearing_bucket_yields_empty_bucket(self, storage):