# pyinfra/scripts/test_async_class.py
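"""Manual integration test for RabbitMQHandler.

The handler listens for tenant lifecycle events, maintains one request queue
per tenant, runs a pluggable message processor on incoming service requests,
and publishes results to a response exchange. Running this script requires a
reachable RabbitMQ broker, S3-compatible storage, and the tenant service.
"""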
import asyncio
import gzip
from operator import itemgetter
from typing import Any, Callable, Dict, Set
import aiohttp
from aio_pika import connect_robust, ExchangeType, Message, IncomingMessage
from aio_pika.abc import AbstractIncomingMessage
import json
from logging import getLogger
from pyinfra.config.loader import load_settings, local_pyinfra_root_path
from pyinfra.storage.storages.s3 import get_s3_storage_from_settings

logger = getLogger(__name__)
logger.setLevel("DEBUG")

settings = load_settings(local_pyinfra_root_path / "config/")


class RabbitMQHandler:
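    """Manage per-tenant RabbitMQ queues for a request/response service.

    Consumes tenant lifecycle events ("tenant.create" / "tenant.delete") from
    a topic exchange, maintains one request queue per tenant, and publishes
    processed results to the response exchange.
    """
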
    def __init__(self, connection_params, tenant_service_url, message_processor):
        self.connection_params = connection_params
        self.tenant_service_url = tenant_service_url
        # TODO: remove hardcoded values
        self.input_queue_prefix = "service_request_queue"
        self.tenant_exchange_name = "tenants-exchange"
        self.service_request_exchange_name = "service_request_exchange"  # INPUT
        self.service_response_exchange_name = "service_response_exchange"  # OUTPUT
        self.service_dead_letter_queue_name = "service_dlq"
        self.queue_expiration_time = 300000  # x-expires: idle-queue TTL in milliseconds (5 minutes)
        self.message_processor = message_processor
        self.connection = None
        self.channel = None
        self.tenant_exchange = None
        self.input_exchange = None
        self.output_exchange = None
        self.tenant_queues = {}

    async def connect(self):
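        """Open a robust connection and channel, and declare the three exchanges."""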
        self.connection = await connect_robust(**self.connection_params)
        self.channel = await self.connection.channel()
        # Declare exchanges
        self.tenant_exchange = await self.channel.declare_exchange(
            self.tenant_exchange_name, ExchangeType.TOPIC, durable=True
        )
        self.input_exchange = await self.channel.declare_exchange(
            self.service_request_exchange_name, ExchangeType.DIRECT, durable=True
        )
        self.output_exchange = await self.channel.declare_exchange(
            self.service_response_exchange_name, ExchangeType.DIRECT, durable=True
        )

    async def setup_tenant_queue(self):
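        """Declare the shared tenant-events queue, bind it to ``tenant.*``, and consume it."""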
        queue = await self.channel.declare_queue(
            "tenant_queue",
            durable=True,
            arguments={
                "x-dead-letter-exchange": "",
                "x-dead-letter-routing-key": self.service_dead_letter_queue_name,
            },
        )
        await queue.bind(self.tenant_exchange, routing_key="tenant.*")
        await queue.consume(self.process_tenant_message)

    async def process_tenant_message(self, message: AbstractIncomingMessage):
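        """Dispatch a tenant lifecycle event to queue creation or deletion."""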
        async with message.process():
            message_body = json.loads(message.body.decode())
            logger.debug(f"Tenant event received: {message_body}")
            tenant_id = message_body["tenantId"]
            routing_key = message.routing_key
            if routing_key == "tenant.create":
                await self.create_tenant_queues(tenant_id)
            elif routing_key == "tenant.delete":
                await self.delete_tenant_queues(tenant_id)

    async def create_tenant_queues(self, tenant_id):
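        """Declare, bind, and consume the input queue for a single tenant."""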
        # Create and bind the per-tenant input queue
        queue_name = f"{self.input_queue_prefix}_{tenant_id}"
        input_queue = await self.channel.declare_queue(
            queue_name,
            durable=True,
            arguments={
                "x-dead-letter-exchange": "",
                "x-dead-letter-routing-key": self.service_dead_letter_queue_name,
                "x-expires": self.queue_expiration_time,
                "x-max-priority": 2,
            },
        )
        await input_queue.bind(self.input_exchange, routing_key=tenant_id)
        await input_queue.consume(lambda msg: self.process_input_message(msg, self.message_processor))
        # Track the queue so it can be deleted on tenant removal
        self.tenant_queues[tenant_id] = input_queue
        logger.info(f"Created queues for tenant {tenant_id}")

    async def delete_tenant_queues(self, tenant_id):
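        """Delete the input queue for a tenant, if one is tracked."""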
        if tenant_id in self.tenant_queues:
            input_queue = self.tenant_queues.pop(tenant_id)
            await input_queue.delete()
            logger.info(f"Deleted queues for tenant {tenant_id}")

    async def process_input_message(
        self, message: IncomingMessage, message_processor: Callable[[Dict[str, Any]], Any]
    ) -> None:
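        """Decode a service request, run the (sync or async) processor, and publish any result."""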
        async with message.process():
            try:
                tenant_id = message.routing_key
                message_body = json.loads(message.body.decode("utf-8"))
                # Extract the x-* headers and merge them into the payload
                filtered_message_headers = (
                    {k: v for k, v in message.headers.items() if k.lower().startswith("x-")} if message.headers else {}
                )
                logger.debug(f"Processing message with {filtered_message_headers=}.")
                message_body.update(filtered_message_headers)
                if asyncio.iscoroutinefunction(message_processor):
                    # Async processors run directly on the event loop
                    result = await message_processor(message_body)
                else:
                    # Sync processors run in a worker thread so they do not block the loop
                    loop = asyncio.get_running_loop()
                    result = await loop.run_in_executor(None, message_processor, message_body)
                if result:
                    # Publish the result to the output exchange
                    await self.publish_to_output_exchange(tenant_id, result, filtered_message_headers)
            except json.JSONDecodeError:
                # Swallowed, so message.process() acks: retrying invalid JSON can never succeed
                logger.error(f"Invalid JSON in input message: {message.body}")
            except FileNotFoundError as e:
                # Swallowed, so the message is acked and dropped
                logger.warning(f"{e}, declining message.")
            except Exception as e:
                logger.error(f"Error processing input message: {e}", exc_info=True)
                # Re-raising makes message.process() reject the message, routing it to the DLQ
                raise

    async def publish_to_output_exchange(self, tenant_id: str, result: Dict[str, Any], headers: Dict[str, Any]):
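        """Publish a processed result to the response exchange, routed by tenant id."""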
        await self.output_exchange.publish(
            Message(body=json.dumps(result).encode(), headers=headers),
            routing_key=tenant_id,
        )
        logger.info(f"Published result to queue {tenant_id}.")

    async def fetch_active_tenants(self) -> Set[str]:
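        """Return the set of active tenant ids from the tenant service, or an empty set on failure."""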
        try:
            async with aiohttp.ClientSession() as session:
                async with session.get(self.tenant_service_url) as response:
                    if response.status == 200:
                        data = await response.json()
                        return {tenant["tenantId"] for tenant in data}
                    logger.error(f"Failed to fetch active tenants. Status: {response.status}")
                    return set()
        except aiohttp.ClientError as e:
            logger.error(f"Error fetching active tenants: {e}")
            return set()

    async def initialize_tenant_queues(self):
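        """Create queues for every tenant that is already active at startup."""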
        active_tenants = await self.fetch_active_tenants()
        for tenant_id in active_tenants:
            await self.create_tenant_queues(tenant_id)

    async def run(self):
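        """Connect, set up all queues, and block until cancelled."""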
        await self.connect()
        await self.initialize_tenant_queues()
        await self.setup_tenant_queue()
        logger.info("RabbitMQ handler is running. Press CTRL+C to exit.")
        try:
            await asyncio.Future()  # Run forever
        finally:
            await self.connection.close()


async def main():
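    """Build connection parameters from settings and run the handler until interrupted."""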
    connection_params = {
        "host": settings.rabbitmq.host,
        "port": settings.rabbitmq.port,
        "login": settings.rabbitmq.username,
        "password": settings.rabbitmq.password,
        "client_properties": {"heartbeat": settings.rabbitmq.heartbeat},
    }
    tenant_service_url = "http://localhost:8080/internal/tenants"
    handler = RabbitMQHandler(connection_params, tenant_service_url, dummy_message_processor)
    await handler.run()


########################################################################
def upload_json_and_make_message_body(tenant_id: str) -> Dict[str, Any]:
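    """Upload a gzipped test document to S3 and return the matching request body."""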
    dossier_id, file_id, suffix = "dossier", "file", "json.gz"
    content = {
        "numberOfPages": 7,
        "sectionTexts": "data",
    }
    object_name = f"{tenant_id}/{dossier_id}/{file_id}.{suffix}"
    data = gzip.compress(json.dumps(content).encode("utf-8"))
    storage = get_s3_storage_from_settings(settings)
    if not storage.has_bucket():
        storage.make_bucket()
    storage.put_object(object_name, data)
    message_body = {
        "tenantId": tenant_id,
        "dossierId": dossier_id,
        "fileId": file_id,
        "targetFileExtension": suffix,
        "responseFileExtension": f"result.{suffix}",
    }
    return message_body


def tenant_event_message(tenant_id: str) -> Dict[str, str]:
    return {"tenantId": tenant_id}


async def dummy_message_processor(message: Dict[str, Any]) -> Dict[str, Any]:
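    """Simulate work: read the uploaded object, transform it, and store the result."""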
    logger.info(f"Processing message: {message}")
    await asyncio.sleep(1)  # Simulate processing time
    storage = get_s3_storage_from_settings(settings)
    tenant_id, dossier_id, file_id = itemgetter("tenantId", "dossierId", "fileId")(message)
    suffix = message["responseFileExtension"]
    # Simulate processing by modifying the original content
    object_name = f"{tenant_id}/{dossier_id}/{file_id}.{message['targetFileExtension']}"
    original_content = json.loads(gzip.decompress(storage.get_object(object_name)))
    processed_content = {
        "processedPages": original_content["numberOfPages"],
        "processedSectionTexts": f"Processed: {original_content['sectionTexts']}",
    }
    # Save processed content under the response extension
    processed_object_name = f"{tenant_id}/{dossier_id}/{file_id}.{suffix}"
    processed_data = gzip.compress(json.dumps(processed_content).encode("utf-8"))
    storage.put_object(processed_object_name, processed_data)
    processed_message = message.copy()
    processed_message["processed"] = True
    processed_message["processor_message"] = "This message was processed by the dummy processor"
    logger.info(f"Finished processing message. Result: {processed_message}")
    return processed_message


async def test_rabbitmq_handler():
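    """End-to-end smoke test: create a tenant, send one request, then delete the tenant."""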
    connection_params = {
        "host": settings.rabbitmq.host,
        "port": settings.rabbitmq.port,
        "login": settings.rabbitmq.username,
        "password": settings.rabbitmq.password,
        "client_properties": {"heartbeat": settings.rabbitmq.heartbeat},
    }
    tenant_service_url = "http://localhost:8080/internal/tenants"
    handler = RabbitMQHandler(connection_params, tenant_service_url, dummy_message_processor)
    await handler.connect()
    await handler.setup_tenant_queue()
    # Test tenant creation
    tenant_id = "test_tenant"
    create_message = tenant_event_message(tenant_id)
    await handler.tenant_exchange.publish(
        Message(body=json.dumps(create_message).encode()), routing_key="tenant.create"
    )
    logger.info(f"Sent create tenant message for {tenant_id}")
    # Wait for tenant queue creation
    await asyncio.sleep(2)
    # Test service request
    service_request = upload_json_and_make_message_body(tenant_id)
    await handler.input_exchange.publish(Message(body=json.dumps(service_request).encode()), routing_key=tenant_id)
    logger.info(f"Sent service request for {tenant_id}")
    # Wait for message processing
    await asyncio.sleep(5)
    # Test tenant deletion
    delete_message = tenant_event_message(tenant_id)
    await handler.tenant_exchange.publish(
        Message(body=json.dumps(delete_message).encode()), routing_key="tenant.delete"
    )
    logger.info(f"Sent delete tenant message for {tenant_id}")
    # Wait for tenant queue deletion
    await asyncio.sleep(2)
    await handler.connection.close()


if __name__ == "__main__":
    asyncio.run(test_rabbitmq_handler())