feat(RabbitMQHandler): add async test class
This commit is contained in:
parent
aa23894858
commit
abde776cd1
166
scripts/test_async_class.py
Normal file
166
scripts/test_async_class.py
Normal file
@ -0,0 +1,166 @@
|
||||
import asyncio
|
||||
import aiohttp
|
||||
from aio_pika import connect_robust, ExchangeType, Message
|
||||
from aio_pika.abc import AbstractIncomingMessage
|
||||
import json
|
||||
from logging import getLogger
|
||||
from pyinfra.config.loader import load_settings, local_pyinfra_root_path
|
||||
import requests
|
||||
|
||||
logger = getLogger(__name__)
|
||||
logger.setLevel("DEBUG")
|
||||
|
||||
settings = load_settings(local_pyinfra_root_path / "config/")
|
||||
|
||||
|
||||
class RabbitMQHandler:
|
||||
def __init__(self, connection_params, tenant_service_url):
|
||||
self.connection_params = connection_params
|
||||
self.tenant_service_url = tenant_service_url
|
||||
# TODO: remove hardcoded values
|
||||
self.input_queue_prefix = "service_request_queue"
|
||||
self.tenant_exchange_name = "tenants-exchange"
|
||||
self.service_request_exchange_name = "service_request_exchange" # INPUT
|
||||
self.service_response_exchange_name = "service_response_exchange" # OUTPUT
|
||||
self.service_dead_letter_queue_name = "service_dlq"
|
||||
self.connection = None
|
||||
self.channel = None
|
||||
self.tenant_exchange = None
|
||||
self.input_exchange = None
|
||||
self.output_exchange = None
|
||||
self.tenant_queues = {}
|
||||
|
||||
async def connect(self):
|
||||
self.connection = await connect_robust(**self.connection_params)
|
||||
self.channel = await self.connection.channel()
|
||||
|
||||
# Declare exchanges
|
||||
self.tenant_exchange = await self.channel.declare_exchange(
|
||||
self.tenant_exchange_name, ExchangeType.TOPIC, durable=True
|
||||
)
|
||||
self.input_exchange = await self.channel.declare_exchange(
|
||||
self.service_request_exchange_name, ExchangeType.DIRECT, durable=True
|
||||
)
|
||||
self.output_exchange = await self.channel.declare_exchange(
|
||||
self.service_response_exchange_name, ExchangeType.DIRECT, durable=True
|
||||
)
|
||||
|
||||
async def setup_tenant_queue(self):
|
||||
queue = await self.channel.declare_queue(
|
||||
"tenant_queue",
|
||||
durable=True,
|
||||
arguments={
|
||||
"x-dead-letter-exchange": "",
|
||||
"x-dead-letter-routing-key": self.service_dead_letter_queue_name,
|
||||
},
|
||||
)
|
||||
await queue.bind(self.tenant_exchange, routing_key="tenant.*")
|
||||
await queue.consume(self.process_tenant_message)
|
||||
|
||||
async def process_tenant_message(self, message: AbstractIncomingMessage):
|
||||
async with message.process():
|
||||
message_body = json.loads(message.body.decode())
|
||||
print(message_body)
|
||||
logger.debug(f"input message: {message_body}")
|
||||
tenant_id = message_body["queue_name"]
|
||||
routing_key = message.routing_key
|
||||
|
||||
if routing_key == "tenant.create":
|
||||
await self.create_tenant_queues(tenant_id)
|
||||
elif routing_key == "tenant.delete":
|
||||
await self.delete_tenant_queues(tenant_id)
|
||||
|
||||
async def create_tenant_queues(self, tenant_id):
|
||||
# Create and bind input queue
|
||||
queue_name = f"{self.input_queue_prefix}_{tenant_id}"
|
||||
print(f"queue declared: {queue_name}")
|
||||
input_queue = await self.channel.declare_queue(
|
||||
queue_name,
|
||||
durable=True,
|
||||
arguments={
|
||||
"x-dead-letter-exchange": "",
|
||||
"x-dead-letter-routing-key": self.service_dead_letter_queue_name,
|
||||
"x-expires": self.queue_expiration_time,
|
||||
"x-max-priority": 2,
|
||||
},
|
||||
)
|
||||
await input_queue.bind(self.input_exchange, routing_key=tenant_id)
|
||||
await input_queue.consume(self.process_input_message)
|
||||
|
||||
# Store queues for later use
|
||||
self.tenant_queues[tenant_id] = input_queue
|
||||
print(f"Created queues for tenant {tenant_id}")
|
||||
|
||||
async def delete_tenant_queues(self, tenant_id):
|
||||
if tenant_id in self.tenant_queues:
|
||||
input_queue = self.tenant_queues[tenant_id]
|
||||
await input_queue.delete()
|
||||
del self.tenant_queues[tenant_id]
|
||||
print(f"Deleted queues for tenant {tenant_id}")
|
||||
|
||||
async def process_input_message(self, message: AbstractIncomingMessage):
|
||||
async with message.process():
|
||||
message_body = json.loads(message.body.decode())
|
||||
logger.debug(f"input message: {message_body}")
|
||||
# Process the incoming message
|
||||
processed_content = f"Processed: {message_body}"
|
||||
|
||||
# TODO: add additional processing logic here
|
||||
# ...
|
||||
|
||||
# Publish to the output queue
|
||||
tenant_id = message.routing_key
|
||||
await self.output_exchange.publish(Message(body=processed_content.encode()), routing_key=tenant_id)
|
||||
|
||||
# FIXME: coroutine error
|
||||
async def fetch_active_tenants(self):
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(self.tenant_service_url) as response:
|
||||
if response.status == 200 and response.headers["content-type"].lower() == "application/json":
|
||||
tenants = {await tenant["tenantId"] for tenant in response.json()}
|
||||
return await tenants
|
||||
else:
|
||||
print(f"Failed to fetch active tenants. Status: {response.status}")
|
||||
return set()
|
||||
|
||||
# TODO: remove after fetch_active_tenants is fixed
|
||||
def get_initial_tenant_ids(self) -> set:
|
||||
response = requests.get(self.tenant_service_url, timeout=10)
|
||||
response.raise_for_status() # Raise an HTTPError for bad responses
|
||||
|
||||
if response.headers["content-type"].lower() == "application/json":
|
||||
tenants = {tenant["tenantId"] for tenant in response.json()}
|
||||
return tenants
|
||||
return set()
|
||||
|
||||
async def initialize_tenant_queues(self):
|
||||
active_tenants = self.get_initial_tenant_ids()
|
||||
for tenant_id in active_tenants:
|
||||
await self.create_tenant_queues(tenant_id)
|
||||
|
||||
async def run(self):
|
||||
await self.connect()
|
||||
await self.initialize_tenant_queues()
|
||||
await self.setup_tenant_queue()
|
||||
print("RabbitMQ handler is running. Press CTRL+C to exit.")
|
||||
try:
|
||||
await asyncio.Future() # Run forever
|
||||
finally:
|
||||
await self.connection.close()
|
||||
|
||||
|
||||
async def main():
|
||||
connection_params = {
|
||||
"host": settings.rabbitmq.host,
|
||||
"port": settings.rabbitmq.port,
|
||||
"login": settings.rabbitmq.username,
|
||||
"password": settings.rabbitmq.password,
|
||||
"client_properties": {"heartbeat": settings.rabbitmq.heartbeat},
|
||||
}
|
||||
tenant_service_url = "http://localhost:8080/internal/tenants"
|
||||
handler = RabbitMQHandler(connection_params, tenant_service_url)
|
||||
await handler.run()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
Loading…
x
Reference in New Issue
Block a user