feat: add async_v2

Jonathan Kössler 2024-07-12 12:12:55 +02:00
parent a5162d5bf0
commit 8844df44ce
6 changed files with 375 additions and 198 deletions

View File

@ -3,12 +3,14 @@ import asyncio
from dynaconf import Dynaconf
from fastapi import FastAPI
from kn_utils.logging import logger
from threading import Thread
# from threading import Thread
from pyinfra.config.loader import get_pyinfra_validators, validate_settings
# from pyinfra.queue.threaded_tenants import ServiceQueueManager, TenantQueueManager
from pyinfra.queue.async_tenants_v2 import RabbitMQConfig, RabbitMQHandler
from pyinfra.queue.callback import Callback
from pyinfra.queue.threaded_tenants import ServiceQueueManager, TenantQueueManager
from pyinfra.queue.async_tenants import AsyncQueueManager
from pyinfra.utils.opentelemetry import instrument_pika, setup_trace, instrument_app
from pyinfra.utils.opentelemetry import instrument_app, instrument_pika, setup_trace
from pyinfra.webserver.prometheus import (
add_prometheus_endpoint,
make_prometheus_processing_time_decorator_from_settings,
@ -19,6 +21,22 @@ from pyinfra.webserver.utils import (
)
def get_rabbitmq_config(settings: Dynaconf):
return RabbitMQConfig(
host=settings.rabbitmq.host,
port=settings.rabbitmq.port,
username=settings.rabbitmq.username,
password=settings.rabbitmq.password,
heartbeat=settings.rabbitmq.heartbeat,
input_queue_prefix=settings.rabbitmq.service_request_queue_prefix,
tenant_exchange_name=settings.rabbitmq.service_response_queue_prefix,
service_request_exchange_name=settings.rabbitmq.service_request_exchange_name,
service_response_exchange_name=settings.rabbitmq.service_response_exchange_name,
service_dead_letter_queue_name=settings.rabbitmq.service_dlq_name,
queue_expiration_time=settings.rabbitmq.queue_expiration_time,
)
def start_standard_queue_consumer(
callback: Callback,
settings: Dynaconf,
@ -37,8 +55,8 @@ def start_standard_queue_consumer(
app = app or FastAPI()
tenant_manager = TenantQueueManager(settings)
service_manager = ServiceQueueManager(settings)
# tenant_manager = TenantQueueManager(settings)
# service_manager = ServiceQueueManager(settings)
if settings.metrics.prometheus.enabled:
logger.info("Prometheus metrics enabled.")
@ -52,20 +70,24 @@ def start_standard_queue_consumer(
instrument_app(app)
# manager = AsyncQueueManager(settings=settings, message_processor=callback)
config = get_rabbitmq_config(settings)
manager = RabbitMQHandler(
config=config, tenant_service_url=settings.storage.tenant_server.endpoint, message_processor=callback
)
app = add_health_check_endpoint(app, service_manager.is_ready)
# app = add_health_check_endpoint(app, manager.is_ready)
# app = add_health_check_endpoint(app, service_manager.is_ready)
app = add_health_check_endpoint(app, manager.is_ready)
webserver_thread = create_webserver_thread_from_settings(app, settings)
webserver_thread.start()
thread_t = Thread(target=tenant_manager.start_consuming, daemon=True)
thread_s = Thread(target=service_manager.start_sequential_basic_get, args=(callback,), daemon=True)
# thread_t = Thread(target=tenant_manager.start_consuming, daemon=True)
# thread_s = Thread(target=service_manager.start_sequential_basic_get, args=(callback,), daemon=True)
thread_t.start()
thread_s.start()
# thread_t.start()
# thread_s.start()
thread_t.join()
thread_s.join()
# asyncio.run(manager.start_processing())
# thread_t.join()
# thread_s.join()
asyncio.run(manager.run())
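
For orientation, a minimal sketch (not part of the commit) of how the new v2 handler could be run standalone, using the RabbitMQConfig fields and the RabbitMQHandler constructor introduced in this diff; the connection values, tenant_service_url, and echo_callback are illustrative assumptions only:

import asyncio

from pyinfra.queue.async_tenants_v2 import RabbitMQConfig, RabbitMQHandler


def echo_callback(message_body: dict) -> dict:
    # Hypothetical processor: echo the request back as the result.
    return {"echo": message_body}


config = RabbitMQConfig(
    host="localhost",  # illustrative connection values, not taken from the commit
    port=5672,
    username="guest",
    password="guest",
    heartbeat=60,
    input_queue_prefix="service_request",
    tenant_exchange_name="tenant_events",
    service_request_exchange_name="service_request_exchange",
    service_response_exchange_name="service_response_exchange",
    service_dead_letter_queue_name="service_dlq",
    queue_expiration_time=3600,
)

manager = RabbitMQHandler(
    config=config,
    tenant_service_url="http://tenant-server/tenants",  # assumed endpoint returning [{"tenantId": ...}]
    message_processor=echo_callback,
)

asyncio.run(manager.run())  # consumes until SIGINT/SIGTERM, then shuts down cleanly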

View File

@ -1,28 +1,31 @@
import aiormq
import asyncio
import aio_pika
import concurrent.futures
import requests
import datetime
import json
from aio_pika import Message, DeliveryMode
from aio_pika.abc import AbstractIncomingMessage
from dynaconf import Dynaconf
import time
import uuid
from typing import Callable, Union
import aio_pika
import aiormq
import requests
from aio_pika import DeliveryMode, Message
from aio_pika.abc import AbstractIncomingMessage
from dynaconf import Dynaconf
from kn_utils.logging import logger
from pyinfra.config.loader import validate_settings
from pyinfra.config.loader import (
load_settings,
local_pyinfra_root_path,
validate_settings,
)
from pyinfra.config.validators import queue_manager_validators
from pyinfra.config.loader import load_settings, local_pyinfra_root_path
from pyinfra.queue.callback import make_download_process_upload_callback
MessageProcessor = Callable[[dict], dict]
class AsyncQueueManager:
def __init__(self, settings: Dynaconf, message_processor: Callable = None) -> None:
validate_settings(settings, queue_manager_validators)
@ -36,12 +39,12 @@ class AsyncQueueManager:
self.connection_sleep = settings.rabbitmq.connection_sleep
self.queue_expiration_time = settings.rabbitmq.queue_expiration_time
self.tenant_created_queue_name = self.get_tenant_created_queue_name(settings)
self.tenant_deleted_queue_name = self.get_tenant_deleted_queue_name(settings)
self.tenant_events_dlq_name = self.get_tenant_events_dlq_name(settings)
self.tenant_exchange_name = settings.rabbitmq.tenant_exchange_name
self.service_request_exchange_name = settings.rabbitmq.service_request_exchange_name
self.service_response_exchange_name = settings.rabbitmq.service_response_exchange_name
@ -50,18 +53,16 @@ class AsyncQueueManager:
self.service_dlq_name = settings.rabbitmq.service_dlq_name
@staticmethod
def get_connection_params(settings: Dynaconf):
return {
"host": settings.rabbitmq.host,
"port": settings.rabbitmq.port,
"login": settings.rabbitmq.username,
"password":settings.rabbitmq.password,
"client_properties": {"heartbeat": settings.rabbitmq.heartbeat}
"password": settings.rabbitmq.password,
"client_properties": {"heartbeat": settings.rabbitmq.heartbeat},
}
def get_initial_tenant_ids(self, tenant_endpoint_url: str) -> set:
response = requests.get(tenant_endpoint_url, timeout=10)
response.raise_for_status() # Raise an HTTPError for bad responses
@ -70,7 +71,7 @@ class AsyncQueueManager:
tenants = {tenant["tenantId"] for tenant in response.json()}
return tenants
return set()
def get_tenant_created_queue_name(self, settings: Dynaconf) -> str:
return self.get_queue_name_with_suffix(
suffix=settings.rabbitmq.tenant_created_event_queue_suffix, pod_name=settings.kubernetes.pod_name
@ -96,7 +97,7 @@ class AsyncQueueManager:
def get_default_queue_name(self):
raise NotImplementedError("Queue name method not implemented")
async def is_ready(self) -> bool:
await self.connect()
return self.channel.is_open
@ -108,7 +109,9 @@ class AsyncQueueManager:
for tenant_id in self.active_tenants:
service_request_queue = await self.channel.get_queue(f"{self.service_request_queue_prefix}_{tenant_id}")
await service_request_queue.purge()
service_response_queue = await self.channel.get_queue(f"{self.service_response_queue_prefix}_{tenant_id}")
service_response_queue = await self.channel.get_queue(
f"{self.service_response_queue_prefix}_{tenant_id}"
)
await service_response_queue.purge()
logger.info("Queues purged.")
except aio_pika.exceptions.ChannelInvalidStateError:
@ -132,12 +135,15 @@ class AsyncQueueManager:
await asyncio.gather(tenant_events, service_events)
async def initialize_queues(self):
await self.channel.set_qos(prefetch_count=1)
service_request_exchange = await self.channel.declare_exchange(name=self.service_request_exchange_name, type=aio_pika.ExchangeType.DIRECT)
service_response_exchange = await self.channel.declare_exchange(name=self.service_response_exchange_name, type=aio_pika.ExchangeType.DIRECT)
service_request_exchange = await self.channel.declare_exchange(
name=self.service_request_exchange_name, type=aio_pika.ExchangeType.DIRECT, durable=True
)
service_response_exchange = await self.channel.declare_exchange(
name=self.service_response_exchange_name, type=aio_pika.ExchangeType.DIRECT, durable=True
)
for tenant_id in self.active_tenants:
request_queue_name = f"{self.service_request_queue_prefix}_{tenant_id}"
@ -164,26 +170,30 @@ class AsyncQueueManager:
},
)
await response_queue.bind(exchange=service_response_exchange, routing_key=tenant_id)
async def handle_tenant_events(self):
# Declare the topic exchange for tenant events
exchange = await self.channel.declare_exchange(
self.tenant_exchange_name, aio_pika.ExchangeType.TOPIC
self.tenant_exchange_name, aio_pika.ExchangeType.TOPIC, durable=True
)
# Declare a queue for receiving tenant events
queue = await self.channel.declare_queue("tenant_events_queue", arguments={
queue = await self.channel.declare_queue(
"tenant_events_queue",
arguments={
"x-dead-letter-exchange": "",
"x-dead-letter-routing-key": self.tenant_events_dlq_name,
},
durable=True,)
durable=True,
)
await queue.bind(exchange, routing_key="tenant.*")
async with queue.iterator() as queue_iter:
async for message in queue_iter:
async with message.process(reject_on_redelivered=True):
routing_key = message.routing_key
tenant_id = message.body.decode()
message_body = json.loads(message.body.decode())
tenant_id = message_body["tenantId"]
if routing_key == "tenant.created":
# Handle tenant creation
await self.handle_tenant_created(tenant_id)
@ -222,7 +232,7 @@ class AsyncQueueManager:
exchange = await self.channel.get_exchange(self.service_request_exchange_name)
await queue.bind(exchange=exchange, routing_key=tenant_id)
self.active_tenants.add(tenant_id)
logger.info(f"Created queue for tenant {tenant_id}")
async def delete_tenant_queues(self, tenant_id):
queue_name = f"{self.service_request_queue_prefix}_{tenant_id}"
@ -245,8 +255,15 @@ class AsyncQueueManager:
async def publish_to_service_response_queue(self, tenant_id, result):
service_response_exchange = await self.channel.get_exchange(self.service_response_exchange_name)
await service_response_exchange.publish(aio_pika.Message(body=json.dumps(result).encode()),
routing_key=tenant_id,)
await service_response_exchange.publish(
Message(
body=json.dumps(result).encode(),
delivery_mode=DeliveryMode.NOT_PERSISTENT,
timestamp=datetime.datetime.now(),
message_id=str(uuid.uuid4()),
),
routing_key=tenant_id,
)
async def restart_consumers(self):
# Stop current consumers and start new ones for active tenants
@ -284,11 +301,13 @@ class AsyncQueueManager:
async def process_message_body_and_await_result(unpacked_message_body):
# Processing the message in a separate thread is necessary for the main thread pika client to be able to
# process data events (e.g. heartbeats) while the message is being processed.
with concurrent.futures.ThreadPoolExecutor(max_workers=1) as thread_pool_executor:
logger.info("Processing payload in separate thread.")
future = thread_pool_executor.submit(message_processor, unpacked_message_body)
# with concurrent.futures.ThreadPoolExecutor(max_workers=1) as thread_pool_executor:
# logger.info("Processing payload in separate thread.")
# future = thread_pool_executor.submit(message_processor, unpacked_message_body)
return future.result()
# return future.result()
return {"result": "lovely"}
async def on_message_callback(message: AbstractIncomingMessage):
logger.info(f"Received message from queue with delivery_tag {message.delivery_tag}.")
@ -312,12 +331,13 @@ class AsyncQueueManager:
)
logger.debug(f"Processing message with {filtered_message_headers=}.")
result: dict = await (
process_message_body_and_await_result({**json.loads(message.body), **filtered_message_headers}) or {}
process_message_body_and_await_result({**json.loads(message.body), **filtered_message_headers})
or {}
)
await self.publish_to_service_response_queue(tenant_id, result)
logger.info(f"Published result to queue {tenant_id}.")
await message.ack()
logger.debug(f"Message with {message.delivery_tag=} acknowledged.")
except FileNotFoundError as e:
@ -330,9 +350,8 @@ class AsyncQueueManager:
raise
return on_message_callback
async def publish_message_to_input_queue(
self, tenant_id: str, message: Union[str, bytes, dict]) -> None:
async def publish_message_to_input_queue(self, tenant_id: str, message: Union[str, bytes, dict]) -> None:
if isinstance(message, str):
message = message.encode("utf-8")
elif isinstance(message, dict):
@ -342,12 +361,19 @@ class AsyncQueueManager:
service_request_exchange = await self.channel.get_exchange(self.service_request_exchange_name)
await service_request_exchange.publish(message=Message(body=message, delivery_mode=DeliveryMode.NOT_PERSISTENT), routing_key=tenant_id)
await service_request_exchange.publish(
message=Message(
body=message,
delivery_mode=DeliveryMode.NOT_PERSISTENT,
timestamp=datetime.datetime.now(),
message_id=str(uuid.uuid4()),
),
routing_key=tenant_id,
)
logger.info(f"Published message to queue {tenant_id}.")
async def publish_message_to_tenant_created_queue(
self, message: Union[str, bytes, dict]) -> None:
async def publish_message_to_tenant_created_queue(self, message: Union[str, bytes, dict]) -> None:
if isinstance(message, str):
message = message.encode("utf-8")
elif isinstance(message, dict):
@ -356,12 +382,19 @@ class AsyncQueueManager:
await self.establish_connection()
service_request_exchange = await self.channel.get_exchange(self.tenant_exchange_name)
await service_request_exchange.publish(message=Message(body=message, delivery_mode=DeliveryMode.NOT_PERSISTENT), routing_key="tenant.created")
await service_request_exchange.publish(
message=Message(
body=message,
delivery_mode=DeliveryMode.NOT_PERSISTENT,
timestamp=datetime.datetime.now(),
message_id=str(uuid.uuid4()),
),
routing_key="tenant.created",
)
logger.info(f"Published message to queue {self.tenant_created_queue_name}.")
async def publish_message_to_tenant_deleted_queue(
self, message: Union[str, bytes, dict]) -> None:
async def publish_message_to_tenant_deleted_queue(self, message: Union[str, bytes, dict]) -> None:
if isinstance(message, str):
message = message.encode("utf-8")
elif isinstance(message, dict):
@ -370,15 +403,22 @@ class AsyncQueueManager:
await self.establish_connection()
service_request_exchange = await self.channel.get_exchange(self.tenant_exchange_name)
await service_request_exchange.publish(message=Message(body=message, delivery_mode=DeliveryMode.NOT_PERSISTENT), routing_key="tenant.delete")
await service_request_exchange.publish(
message=Message(
body=message,
delivery_mode=DeliveryMode.NOT_PERSISTENT,
timestamp=datetime.datetime.now(),
message_id=str(uuid.uuid4()),
),
routing_key="tenant.delete",
)
logger.info(f"Published message to queue {self.tenant_deleted_queue_name}.")
async def main() -> None:
import time
settings = load_settings(local_pyinfra_root_path / "config/")
callback = ""
@ -386,10 +426,10 @@ async def main() -> None:
await manager.main_loop()
while True:
time.sleep(100)
print("keep idling")
if __name__ == '__main__':
asyncio.run(main())
if __name__ == "__main__":
asyncio.run(main())
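
The publish calls above now mark messages as non-persistent and attach a timestamp plus a unique message_id. A small hedged helper (name and usage are illustrative, not from the commit) that builds such a message with aio_pika:

import datetime
import json
import uuid

from aio_pika import DeliveryMode, Message


def build_transient_message(payload: dict) -> Message:
    # Non-persistent message with a creation timestamp and a unique id,
    # mirroring the publish calls above.
    return Message(
        body=json.dumps(payload).encode(),
        delivery_mode=DeliveryMode.NOT_PERSISTENT,
        timestamp=datetime.datetime.now(),
        message_id=str(uuid.uuid4()),
    )

# usage, inside an async context with a declared exchange:
#   await exchange.publish(build_transient_message({"tenantId": "t1"}), routing_key="t1")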

View File

@ -1,17 +1,18 @@
import asyncio
import gzip
from operator import itemgetter
from typing import Any, Callable, Dict, Set
import aiohttp
from aio_pika import connect_robust, ExchangeType, Message, IncomingMessage
from aio_pika.abc import AbstractIncomingMessage, AbstractChannel, AbstractConnection, AbstractExchange
import json
from kn_utils.logging import logger
from pyinfra.config.loader import load_settings, local_pyinfra_root_path
from pyinfra.storage.storages.s3 import get_s3_storage_from_settings
import signal
from dataclasses import dataclass, field
from typing import Any, Callable, Dict, Set
settings = load_settings(local_pyinfra_root_path / "config/")
import aiohttp
from aio_pika import ExchangeType, IncomingMessage, Message, connect_robust
from aio_pika.abc import (
AbstractChannel,
AbstractConnection,
AbstractExchange,
AbstractIncomingMessage,
)
from kn_utils.logging import logger
@dataclass
@ -63,6 +64,10 @@ class RabbitMQHandler:
self.channel = await self.connection.channel()
await self.channel.set_qos(prefetch_count=1)
async def is_ready(self) -> bool:
await self.connect()
return self.channel.is_open
async def setup_exchanges(self):
self.tenant_exchange = await self.channel.declare_exchange(
self.config.tenant_exchange_name, ExchangeType.TOPIC, durable=True
@ -75,6 +80,7 @@ class RabbitMQHandler:
)
async def setup_tenant_queue(self) -> None:
# TODO: Add k8s pod_name to tenant queue name - add DLQ?
queue = await self.channel.declare_queue(
"tenant_queue",
durable=True,
@ -93,7 +99,7 @@ class RabbitMQHandler:
tenant_id = message_body["tenantId"]
routing_key = message.routing_key
if routing_key == "tenant.create":
if routing_key == "tenant.created":
await self.create_tenant_queues(tenant_id)
elif routing_key == "tenant.delete":
await self.delete_tenant_queues(tenant_id)
@ -119,16 +125,30 @@ class RabbitMQHandler:
async def delete_tenant_queues(self, tenant_id: str) -> None:
if tenant_id in self.tenant_queues:
input_queue = self.tenant_queues[tenant_id]
await input_queue.delete()
# somehow queue.delete() does not work here
await self.channel.queue_delete(f"{self.config.input_queue_prefix}_{tenant_id}")
del self.tenant_queues[tenant_id]
logger.info(f"Deleted queues for tenant {tenant_id}")
async def process_input_message(self, message: IncomingMessage) -> None:
async with message.process():
async def process_message_body_and_await_result(unpacked_message_body):
return self.message_processor(unpacked_message_body)
async with message.process(ignore_processed=True):
if message.redelivered:
logger.warning(f"Declining message with {message.delivery_tag=} due to it being redelivered.")
await message.nack(requeue=False)
return
if message.body.decode("utf-8") == "STOP":
logger.info("Received stop signal, stopping consumption...")
await message.ack()
# TODO: shutdown is probably not the right call here - align with Dev on what should happen on a stop signal
await self.shutdown()
return
try:
tenant_id = message.routing_key
message_body = json.loads(message.body.decode("utf-8"))
filtered_message_headers = (
{k: v for k, v in message.headers.items() if k.lower().startswith("x-")} if message.headers else {}
@ -136,18 +156,26 @@ class RabbitMQHandler:
logger.debug(f"Processing message with {filtered_message_headers=}.")
message_body.update(filtered_message_headers)
result = await self.message_processor(message_body)
result: dict = await (
process_message_body_and_await_result({**json.loads(message.body), **filtered_message_headers})
or {}
)
if result:
await self.publish_to_output_exchange(tenant_id, result, filtered_message_headers)
await message.ack()
logger.debug(f"Message with {message.delivery_tag=} acknowledged.")
else:
raise ValueError(f"Could not process message with {message.body=}.")
except json.JSONDecodeError:
await message.nack(requeue=False)
logger.error(f"Invalid JSON in input message: {message.body}")
except FileNotFoundError as e:
logger.warning(f"{e}, declining message.")
logger.warning(f"{e}, declining message with {message.delivery_tag=}.")
await message.nack(requeue=False)
except Exception as e:
await message.nack(requeue=False)
logger.error(f"Error processing input message: {e}", exc_info=True)
raise
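
Before normal processing, the reworked process_input_message adds two guards: redelivered messages are nacked without requeue (so they land in the dead-letter queue), and a literal "STOP" body is acknowledged and triggers shutdown. A compact, hedged sketch of just that guard logic (guarded_handler is a stand-in name, not part of the commit):

import json

from aio_pika.abc import AbstractIncomingMessage


async def guarded_handler(message: AbstractIncomingMessage) -> None:
    async with message.process(ignore_processed=True):
        # Redelivered messages are not retried again; nack routes them to the DLQ.
        if message.redelivered:
            await message.nack(requeue=False)
            return
        # Control message: a plain "STOP" body ends consumption (the real handler calls shutdown()).
        if message.body.decode("utf-8") == "STOP":
            await message.ack()
            return
        payload = json.loads(message.body.decode("utf-8"))
        print(f"processing {payload}")  # normal processing would continue here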
@ -162,6 +190,8 @@ class RabbitMQHandler:
try:
async with aiohttp.ClientSession() as session:
async with session.get(self.tenant_service_url) as response:
# TODO: don't know if we should check for 200, could also be 2xx
# maybe handle bad requests with response.raise_for_status()
if response.status == 200:
data = await response.json()
return {tenant["tenantId"] for tenant in data}
@ -178,105 +208,37 @@ class RabbitMQHandler:
await self.create_tenant_queues(tenant_id)
async def run(self) -> None:
stop = asyncio.Event()
def signal_handler(*_):
logger.info("Signal received, shutting down...")
stop.set()
loop = asyncio.get_running_loop()
for sig in (signal.SIGINT, signal.SIGTERM):
loop.add_signal_handler(sig, signal_handler)
try:
await self.connect()
await self.setup_exchanges()
await self.initialize_tenant_queues()
await self.setup_tenant_queue()
logger.info("RabbitMQ handler is running. Press CTRL+C to exit.")
await asyncio.Future() # Run forever
await stop.wait() # Run until stop signal received
except asyncio.CancelledError:
logger.info("Shutting down RabbitMQ handler...")
logger.info("Operation cancelled.")
except Exception as e:
logger.error(f"An error occurred: {e}", exc_info=True)
finally:
if self.connection:
await self.connection.close()
await self.shutdown()
async def shutdown(self) -> None:
logger.info("Shutting down RabbitMQ handler...")
if self.channel:
await self.channel.close()
if self.connection:
await self.connection.close()
logger.info("RabbitMQ handler shut down successfully.")
async def dummy_message_processor(message: Dict[str, Any]) -> Dict[str, Any]:
logger.info(f"Processing message: {message}")
await asyncio.sleep(1) # Simulate processing time
storage = get_s3_storage_from_settings(settings)
tenant_id, dossier_id, file_id = itemgetter("tenantId", "dossierId", "fileId")(message)
suffix = message["responseFileExtension"]
object_name = f"{tenant_id}/{dossier_id}/{file_id}.{message['targetFileExtension']}"
original_content = json.loads(gzip.decompress(storage.get_object(object_name)))
processed_content = {
"processedPages": original_content["numberOfPages"],
"processedSectionTexts": f"Processed: {original_content['sectionTexts']}",
}
processed_object_name = f"{tenant_id}/{dossier_id}/{file_id}.{suffix}"
processed_data = gzip.compress(json.dumps(processed_content).encode("utf-8"))
storage.put_object(processed_object_name, processed_data)
processed_message = message.copy()
processed_message["processed"] = True
processed_message["processor_message"] = "This message was processed by the dummy processor"
logger.info(f"Finished processing message. Result: {processed_message}")
return processed_message
async def test_rabbitmq_handler() -> None:
tenant_service_url = settings.storage.tenant_server.endpoint
config = RabbitMQConfig(
host=settings.rabbitmq.host,
port=settings.rabbitmq.port,
username=settings.rabbitmq.username,
password=settings.rabbitmq.password,
heartbeat=settings.rabbitmq.heartbeat,
input_queue_prefix=settings.rabbitmq.service_request_queue_prefix,
tenant_exchange_name=settings.rabbitmq.service_response_queue_prefix,
service_request_exchange_name=settings.rabbitmq.service_request_exchange_name,
service_response_exchange_name=settings.rabbitmq.service_response_exchange_name,
service_dead_letter_queue_name=settings.rabbitmq.service_dlq_name,
queue_expiration_time=settings.rabbitmq.queue_expiration_time,
)
handler = RabbitMQHandler(config, tenant_service_url, dummy_message_processor)
await handler.connect()
await handler.setup_exchanges()
await handler.initialize_tenant_queues()
await handler.setup_tenant_queue()
tenant_id = "test_tenant"
# Test tenant creation
create_message = {"tenantId": tenant_id}
await handler.tenant_exchange.publish(
Message(body=json.dumps(create_message).encode()), routing_key="tenant.create"
)
logger.info(f"Sent create tenant message for {tenant_id}")
await asyncio.sleep(2) # Wait for queue creation
# Test service request
service_request = {
"tenantId": tenant_id,
"dossierId": "dossier",
"fileId": "file",
"targetFileExtension": "json.gz",
"responseFileExtension": "result.json.gz",
}
await handler.input_exchange.publish(Message(body=json.dumps(service_request).encode()), routing_key=tenant_id)
logger.info(f"Sent service request for {tenant_id}")
await asyncio.sleep(5) # Wait for message processing
# Test tenant deletion
delete_message = {"tenantId": tenant_id}
await handler.tenant_exchange.publish(
Message(body=json.dumps(delete_message).encode()), routing_key="tenant.delete"
)
logger.info(f"Sent delete tenant message for {tenant_id}")
await asyncio.sleep(2) # Wait for queue deletion
await handler.connection.close()
if __name__ == "__main__":
asyncio.run(test_rabbitmq_handler())
# TODO: purge_queues
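
run() now waits on an asyncio.Event set from SIGINT/SIGTERM handlers instead of awaiting asyncio.Future() forever, so shutdown() always gets a chance to run. A self-contained sketch of that pattern without RabbitMQ (POSIX only, since add_signal_handler is not available on Windows; worker() is a stand-in for the consumer):

import asyncio
import signal


async def worker() -> None:
    # Stand-in for the consuming side; runs until cancelled.
    while True:
        await asyncio.sleep(1)


async def main() -> None:
    stop = asyncio.Event()
    loop = asyncio.get_running_loop()
    for sig in (signal.SIGINT, signal.SIGTERM):
        loop.add_signal_handler(sig, stop.set)  # set the event instead of raising KeyboardInterrupt
    task = asyncio.create_task(worker())
    try:
        await stop.wait()  # block here until a signal arrives
    finally:
        task.cancel()  # cleanup step, analogous to RabbitMQHandler.shutdown()


if __name__ == "__main__":
    asyncio.run(main())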

View File

@ -0,0 +1,153 @@
import asyncio
import gzip
import json
from operator import itemgetter
from typing import Any, Dict
from aio_pika import Message
from aio_pika.abc import AbstractIncomingMessage
from kn_utils.logging import logger
from pyinfra.config.loader import load_settings, local_pyinfra_root_path
from pyinfra.queue.async_tenants_v2 import RabbitMQConfig, RabbitMQHandler
from pyinfra.storage.storages.s3 import S3Storage, get_s3_storage_from_settings
settings = load_settings(local_pyinfra_root_path / "config/")
async def dummy_message_processor(message: Dict[str, Any]) -> Dict[str, Any]:
logger.info(f"Processing message: {message}")
# await asyncio.sleep(1) # Simulate processing time
storage = get_s3_storage_from_settings(settings)
tenant_id, dossier_id, file_id = itemgetter("tenantId", "dossierId", "fileId")(message)
suffix = message["responseFileExtension"]
object_name = f"{tenant_id}/{dossier_id}/{file_id}.{message['targetFileExtension']}"
original_content = json.loads(gzip.decompress(storage.get_object(object_name)))
processed_content = {
"processedPages": original_content["numberOfPages"],
"processedSectionTexts": f"Processed: {original_content['sectionTexts']}",
}
processed_object_name = f"{tenant_id}/{dossier_id}/{file_id}.{suffix}"
processed_data = gzip.compress(json.dumps(processed_content).encode("utf-8"))
storage.put_object(processed_object_name, processed_data)
processed_message = message.copy()
processed_message["processed"] = True
processed_message["processor_message"] = "This message was processed by the dummy processor"
logger.info(f"Finished processing message. Result: {processed_message}")
return processed_message
async def on_response_message_callback(storage: S3Storage):
async def on_message(message: AbstractIncomingMessage) -> None:
async with message.process(ignore_processed=True):
if not message.body:
raise ValueError
response = json.loads(message.body)
logger.info(f"Received {response}")
logger.info(f"Message headers: {message.properties.headers}")
await message.ack()
tenant_id, dossier_id, file_id = itemgetter("tenantId", "dossierId", "fileId")(response)
suffix = response["responseFileExtension"]
result = storage.get_object(f"{tenant_id}/{dossier_id}/{file_id}.{suffix}")
result = json.loads(gzip.decompress(result))
logger.info(f"Contents of result on storage: {result}")
return on_message
def upload_json_and_make_message_body(tenant_id: str):
dossier_id, file_id, suffix = "dossier", "file", "json.gz"
content = {
"numberOfPages": 7,
"sectionTexts": "data",
}
object_name = f"{tenant_id}/{dossier_id}/{file_id}.{suffix}"
data = gzip.compress(json.dumps(content).encode("utf-8"))
storage = get_s3_storage_from_settings(settings)
if not storage.has_bucket():
storage.make_bucket()
storage.put_object(object_name, data)
message_body = {
"tenantId": tenant_id,
"dossierId": dossier_id,
"fileId": file_id,
"targetFileExtension": suffix,
"responseFileExtension": f"result.{suffix}",
}
return message_body, storage
async def test_rabbitmq_handler() -> None:
tenant_service_url = settings.storage.tenant_server.endpoint
config = RabbitMQConfig(
host=settings.rabbitmq.host,
port=settings.rabbitmq.port,
username=settings.rabbitmq.username,
password=settings.rabbitmq.password,
heartbeat=settings.rabbitmq.heartbeat,
input_queue_prefix=settings.rabbitmq.service_request_queue_prefix,
tenant_exchange_name=settings.rabbitmq.service_response_queue_prefix,
service_request_exchange_name=settings.rabbitmq.service_request_exchange_name,
service_response_exchange_name=settings.rabbitmq.service_response_exchange_name,
service_dead_letter_queue_name=settings.rabbitmq.service_dlq_name,
queue_expiration_time=settings.rabbitmq.queue_expiration_time,
)
handler = RabbitMQHandler(config, tenant_service_url, dummy_message_processor)
await handler.connect()
await handler.setup_exchanges()
# await handler.initialize_tenant_queues()
# await handler.setup_tenant_queue()
# for queue in handler.tenant_queues.values():
# await queue.purge()
tenant_id = "test_tenant"
# Test tenant creation
create_message = {"tenantId": tenant_id}
await handler.tenant_exchange.publish(
Message(body=json.dumps(create_message).encode()), routing_key="tenant.created"
)
logger.info(f"Sent create tenant message for {tenant_id}")
await asyncio.sleep(2) # Wait for queue creation
# Prepare service request
service_request, storage = upload_json_and_make_message_body(tenant_id)
# Test service request
await handler.input_exchange.publish(Message(body=json.dumps(service_request).encode()), routing_key=tenant_id)
logger.info(f"Sent service request for {tenant_id}")
await asyncio.sleep(5) # Wait for message processing
# Consume service request
response_queue = await handler.channel.declare_queue(name=f"response_queue_{tenant_id}")
await response_queue.bind(exchange=handler.output_exchange, routing_key=tenant_id)
callback = await on_response_message_callback(storage)
await response_queue.consume(callback=callback)
await asyncio.sleep(5) # Wait for message processing
# Test tenant deletion
delete_message = {"tenantId": tenant_id}
await handler.tenant_exchange.publish(
Message(body=json.dumps(delete_message).encode()), routing_key="tenant.delete"
)
logger.info(f"Sent delete tenant message for {tenant_id}")
await asyncio.sleep(2) # Wait for queue deletion
await handler.connection.close()
if __name__ == "__main__":
asyncio.run(test_rabbitmq_handler())
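
For reference, the demo above relies on a simple S3 object-name convention: the request body carries the path components plus the input and output file extensions. A hedged illustration using the literal values from test_rabbitmq_handler:

# Object names derived from the service request used in test_rabbitmq_handler above.
request = {
    "tenantId": "test_tenant",
    "dossierId": "dossier",
    "fileId": "file",
    "targetFileExtension": "json.gz",
    "responseFileExtension": "result.json.gz",
}

input_object = f"{request['tenantId']}/{request['dossierId']}/{request['fileId']}.{request['targetFileExtension']}"
# -> "test_tenant/dossier/file.json.gz": what dummy_message_processor downloads
output_object = f"{request['tenantId']}/{request['dossierId']}/{request['fileId']}.{request['responseFileExtension']}"
# -> "test_tenant/dossier/file.result.json.gz": where the processed result is uploaded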

View File

@ -2,14 +2,14 @@ import asyncio
import gzip
import json
import time
from aio_pika.abc import AbstractIncomingMessage
from operator import itemgetter
from aio_pika.abc import AbstractIncomingMessage
from kn_utils.logging import logger
from pyinfra.config.loader import load_settings, local_pyinfra_root_path
from pyinfra.queue.async_tenants import AsyncQueueManager
from pyinfra.storage.storages.s3 import get_s3_storage_from_settings
from pyinfra.storage.storages.s3 import S3Storage
from pyinfra.storage.storages.s3 import S3Storage, get_s3_storage_from_settings
settings = load_settings(local_pyinfra_root_path / "config/")
@ -57,6 +57,7 @@ def on_message_callback(storage: S3Storage):
result = storage.get_object(f"{tenant_id}/{dossier_id}/{file_id}.{suffix}")
result = json.loads(gzip.decompress(result))
logger.info(f"Contents of result on storage: {result}")
return on_message
@ -98,9 +99,7 @@ async def send_service_request(queue_manager: AsyncQueueManager, tenant_id: str)
if __name__ == "__main__":
# tenant_ids = ["a", "b", "c", "d"]
queue_manager = AsyncQueueManager(settings)
# asyncio.run(send_tenant_event(queue_manager, "test", "create"))
asyncio.run(send_service_request(queue_manager,"redaction"))
# asyncio.run(send_tenant_event(AsyncQueueManager(settings), "test_1", "create"))
asyncio.run(send_service_request(AsyncQueueManager(settings), "redaction"))
# asyncio.run(consume_service_request(AsyncQueueManager(settings),"redaction"))

View File

@ -3,6 +3,7 @@ import json
import time
from operator import itemgetter
from threading import Thread
from kn_utils.logging import logger
from pyinfra.config.loader import load_settings, local_pyinfra_root_path
@ -67,9 +68,7 @@ def send_service_request(tenant_id: str):
storage = get_s3_storage_from_settings(settings)
for method_frame, properties, body in queue_manager.channel.consume(
queue=queue_name, inactivity_timeout=15
):
for method_frame, properties, body in queue_manager.channel.consume(queue=queue_name, inactivity_timeout=15):
if not body:
break
response = json.loads(body)
@ -87,13 +86,15 @@ def send_service_request(tenant_id: str):
if __name__ == "__main__":
tenant_ids = ["a", "b", "c", "d"]
import uuid
for tenant in tenant_ids:
unique_ids = [str(uuid.uuid4()) for _ in range(100)]
for tenant in unique_ids:
send_tenant_event(tenant_id=tenant, event_type="create")
for tenant in tenant_ids:
send_service_request(tenant_id=tenant)
# for tenant in tenant_ids:
# send_service_request(tenant_id=tenant)
for tenant in tenant_ids:
send_tenant_event(tenant_id=tenant, event_type="delete")
# for tenant in tenant_ids:
# send_tenant_event(tenant_id=tenant, event_type="delete")