pyinfra/pyinfra/queue/async_manager.py
2024-07-25 14:45:19 +02:00

259 lines
10 KiB
Python

import asyncio
import json
import signal
from dataclasses import dataclass, field
from typing import Any, Callable, Dict, Set

import aiohttp
from aio_pika import ExchangeType, IncomingMessage, Message, connect_robust
from aio_pika.abc import (
    AbstractChannel,
    AbstractConnection,
    AbstractExchange,
    AbstractIncomingMessage,
    AbstractQueue,
)
from kn_utils.logging import logger
from tenacity import (
    retry,
    retry_if_exception_type,
    stop_after_attempt,
    wait_exponential_jitter,
)
@dataclass
class RabbitMQConfig:
    """Settings for connecting to RabbitMQ and naming the queues/exchanges used.

    ``connection_params`` is not a constructor argument; it is derived in
    ``__post_init__`` and holds the keyword arguments this module passes to
    ``aio_pika.connect_robust``.
    """

    # Broker endpoint and credentials.
    host: str
    port: int
    username: str
    password: str
    heartbeat: int
    # Queue naming.
    input_queue_prefix: str
    tenant_event_queue_suffix: str
    # Exchange names.
    tenant_exchange_name: str
    service_request_exchange_name: str
    service_response_exchange_name: str
    # Dead-lettering / expiry applied to declared queues.
    service_dead_letter_queue_name: str
    queue_expiration_time: int
    # Used to build this pod's tenant-event queue name.
    pod_name: str
    connection_params: Dict[str, object] = field(init=False)

    def __post_init__(self):
        # NOTE(review): the heartbeat travels inside client_properties rather than
        # as a top-level connection argument — confirm the aio_pika version in use
        # actually applies it as the AMQP heartbeat interval.
        props = {"heartbeat": self.heartbeat}
        self.connection_params = dict(
            host=self.host,
            port=self.port,
            login=self.username,
            password=self.password,
            client_properties=props,
        )
class AsyncQueueManager:
    """Consume per-tenant input queues and publish processing results.

    The manager listens on a tenant-event queue bound to the tenant exchange
    (``tenant.*`` routing keys) and creates or deletes one input queue per
    tenant in response. Each input message is handed to ``message_processor``;
    a truthy result is published to the output exchange using the tenant id as
    routing key.
    """

    def __init__(
        self,
        config: RabbitMQConfig,
        tenant_service_url: str,
        message_processor: Callable[[Dict[str, Any]], Dict[str, Any]],
    ):
        """
        Args:
            config: Broker connection settings and queue/exchange names.
            tenant_service_url: URL of the tenant service; expected to return a
                JSON list of objects carrying a ``tenantId`` key.
            message_processor: Synchronous callable that turns a decoded input
                message (JSON body merged with its ``x-`` headers) into a
                result dict.
        """
        self.config = config
        self.tenant_service_url = tenant_service_url
        self.message_processor = message_processor
        self.connection: AbstractConnection | None = None
        self.channel: AbstractChannel | None = None
        self.tenant_exchange: AbstractExchange | None = None
        self.input_exchange: AbstractExchange | None = None
        self.output_exchange: AbstractExchange | None = None
        # Input queues this manager consumes from, keyed by tenant id.
        self.tenant_queues: Dict[str, AbstractQueue] = {}

    async def connect(self) -> None:
        """Open a robust connection and a channel with ``prefetch_count=1``."""
        self.connection = await connect_robust(**self.config.connection_params)
        self.channel = await self.connection.channel()
        # Limit unacknowledged deliveries so messages are handed out one by one.
        await self.channel.set_qos(prefetch_count=1)

    async def is_ready(self) -> bool:
        """Return whether the broker channel is open, connecting first if needed.

        A new connection is only opened when none exists or the current one is
        closed, so repeated readiness probes do not leak connections.
        """
        if self.connection is None or self.connection.is_closed:
            await self.connect()
        # ``is_closed`` is a plain property, not a coroutine.
        return self.channel is not None and not self.channel.is_closed

    async def setup_exchanges(self) -> None:
        """Declare the durable tenant (topic), input (direct) and output (direct) exchanges."""
        self.tenant_exchange = await self.channel.declare_exchange(
            self.config.tenant_exchange_name, ExchangeType.TOPIC, durable=True
        )
        self.input_exchange = await self.channel.declare_exchange(
            self.config.service_request_exchange_name, ExchangeType.DIRECT, durable=True
        )
        self.output_exchange = await self.channel.declare_exchange(
            self.config.service_response_exchange_name, ExchangeType.DIRECT, durable=True
        )

    def _queue_arguments(self) -> Dict[str, Any]:
        """Return the queue arguments shared by every queue this manager declares."""
        return {
            # Default exchange: dead letters are routed straight to the queue
            # named by the routing key below.
            "x-dead-letter-exchange": "",
            "x-dead-letter-routing-key": self.config.service_dead_letter_queue_name,
            "x-expires": self.config.queue_expiration_time,
            "x-max-priority": 2,
        }

    async def setup_tenant_queue(self) -> None:
        """Declare this pod's tenant-event queue, bind it to ``tenant.*`` and start consuming."""
        queue = await self.channel.declare_queue(
            f"{self.config.pod_name}_{self.config.tenant_event_queue_suffix}",
            durable=True,
            arguments=self._queue_arguments(),
        )
        await queue.bind(self.tenant_exchange, routing_key="tenant.*")
        await queue.consume(self.process_tenant_message)

    async def process_tenant_message(self, message: AbstractIncomingMessage) -> None:
        """Handle a tenant lifecycle event.

        ``tenant.created`` creates the tenant's input queue and ``tenant.delete``
        removes it; any other routing key is acknowledged and ignored. Errors
        (e.g. invalid JSON, missing ``tenantId``) propagate out of
        ``message.process()``, which rejects the message.
        """
        async with message.process():
            message_body = json.loads(message.body.decode())
            logger.debug(f"Tenant message received: {message_body}")
            tenant_id = message_body["tenantId"]
            routing_key = message.routing_key
            if routing_key == "tenant.created":
                await self.create_tenant_queues(tenant_id)
            # NOTE(review): "tenant.delete" vs "tenant.created" is asymmetric —
            # confirm against the routing keys the publisher actually uses.
            elif routing_key == "tenant.delete":
                await self.delete_tenant_queues(tenant_id)

    async def create_tenant_queues(self, tenant_id: str) -> None:
        """Declare, bind and start consuming the input queue of ``tenant_id``.

        Does nothing when this manager already consumes the tenant's queue, so a
        repeated ``tenant.created`` event (or one for a tenant already set up at
        startup) does not attach a second consumer.
        """
        if tenant_id in self.tenant_queues:
            logger.debug(f"Queues for tenant {tenant_id} already exist, skipping.")
            return
        queue_name = f"{self.config.input_queue_prefix}_{tenant_id}"
        logger.info(f"Declaring queue: {queue_name}")
        input_queue = await self.channel.declare_queue(
            queue_name,
            durable=True,
            arguments=self._queue_arguments(),
        )
        await input_queue.bind(self.input_exchange, routing_key=tenant_id)
        await input_queue.consume(self.process_input_message)
        self.tenant_queues[tenant_id] = input_queue
        logger.info(f"Created queues for tenant {tenant_id}")

    async def delete_tenant_queues(self, tenant_id: str) -> None:
        """Delete the input queue of ``tenant_id`` if this manager consumes it."""
        if tenant_id not in self.tenant_queues:
            return
        # Delete by name through the channel. NOTE(review): Queue.delete() defaults
        # to if_unused=True/if_empty=True, which the broker likely refuses while our
        # own consumer is attached — presumably why it "did not work" here.
        await self.channel.queue_delete(f"{self.config.input_queue_prefix}_{tenant_id}")
        del self.tenant_queues[tenant_id]
        logger.info(f"Deleted queues for tenant {tenant_id}")

    async def process_input_message(self, message: IncomingMessage) -> None:
        """Process one tenant input message and publish its result.

        Redelivered messages, invalid JSON, processor failures and falsy results
        are nacked without requeue, so the queue's dead-letter settings apply.
        The literal body ``STOP`` acknowledges the message and shuts the manager
        down.

        Raises:
            Exception: Unexpected processing errors are re-raised after the nack.
        """
        async with message.process(ignore_processed=True):
            if message.redelivered:
                logger.warning(f"Declining message with {message.delivery_tag=} due to it being redelivered.")
                await message.nack(requeue=False)
                return

            if message.body.decode("utf-8") == "STOP":
                logger.info("Received stop signal, stopping consumption...")
                await message.ack()
                # TODO: shutdown is probably not the right call here - align w/ Dev what should happen on stop signal
                # NOTE(review): this does not set run()'s stop event, so run() keeps
                # waiting after the channel/connection are closed.
                await self.shutdown()
                return

            try:
                tenant_id = message.routing_key
                # Only x-* headers are forwarded to the processor and to the output message.
                filtered_message_headers = (
                    {k: v for k, v in message.headers.items() if k.lower().startswith("x-")} if message.headers else {}
                )
                logger.debug(f"Processing message with {filtered_message_headers=}.")
                # NOTE(review): message_processor is synchronous and runs on the event
                # loop, so other consumers are blocked while it works.
                result: dict = (
                    self.message_processor({**json.loads(message.body), **filtered_message_headers}) or {}
                )
                if result:
                    await self.publish_to_output_exchange(tenant_id, result, filtered_message_headers)
                    await message.ack()
                    logger.debug(f"Message with {message.delivery_tag=} acknowledged.")
                else:
                    raise ValueError(f"Could not process message with {message.body=}.")

            except json.JSONDecodeError:
                await message.nack(requeue=False)
                logger.error(f"Invalid JSON in input message: {message.body}")
            except FileNotFoundError as e:
                logger.warning(f"{e}, declining message with {message.delivery_tag=}.")
                await message.nack(requeue=False)
            except Exception as e:
                await message.nack(requeue=False)
                logger.error(f"Error processing input message: {e}", exc_info=True)
                raise

    async def publish_to_output_exchange(self, tenant_id: str, result: Dict[str, Any], headers: Dict[str, Any]) -> None:
        """Publish ``result`` as JSON to the output exchange, routed by ``tenant_id``."""
        await self.output_exchange.publish(
            Message(body=json.dumps(result).encode(), headers=headers),
            routing_key=tenant_id,
        )
        logger.info(f"Published result to queue {tenant_id}.")

    @retry(
        stop=stop_after_attempt(5),
        wait=wait_exponential_jitter(initial=1, max=10),
        retry=retry_if_exception_type(aiohttp.ClientResponseError),
        reraise=True,
    )
    async def fetch_active_tenants(self) -> Set[str]:
        """Return the ids of all active tenants from the tenant service.

        HTTP error statuses are retried (up to 5 attempts, exponential backoff
        with jitter) and re-raised as ``aiohttp.ClientResponseError`` once
        exhausted. A non-JSON response yields an empty set.
        """
        async with aiohttp.ClientSession() as session:
            async with session.get(self.tenant_service_url) as response:
                response.raise_for_status()
                # ``content_type`` is the bare MIME type: parameters such as
                # "; charset=utf-8" are stripped and a missing header does not raise.
                content_type = response.content_type.lower()
                if content_type != "application/json":
                    logger.error(f"Failed to fetch active tenants. Content type is not JSON: {content_type}")
                    return set()
                data = await response.json()
                return {tenant["tenantId"] for tenant in data}

    async def initialize_tenant_queues(self) -> None:
        """Create input queues for every tenant the tenant service reports as active.

        If the service keeps answering with HTTP errors, no queues are created
        and a warning is logged.
        """
        try:
            active_tenants = await self.fetch_active_tenants()
        except aiohttp.ClientResponseError:
            logger.warning("API calls to tenant server failed. No tenant queues initialized.")
            active_tenants = set()
        for tenant_id in active_tenants:
            await self.create_tenant_queues(tenant_id)

    async def run(self) -> None:
        """Connect, set up exchanges and queues, then consume until SIGINT/SIGTERM.

        Startup errors are logged rather than raised, and ``shutdown`` always
        runs on exit.
        """
        stop = asyncio.Event()

        def signal_handler(*_):
            logger.info("Signal received, shutting down...")
            stop.set()

        loop = asyncio.get_running_loop()
        for sig in (signal.SIGINT, signal.SIGTERM):
            loop.add_signal_handler(sig, signal_handler)

        try:
            await self.connect()
            await self.setup_exchanges()
            await self.initialize_tenant_queues()
            await self.setup_tenant_queue()
            logger.info("RabbitMQ handler is running. Press CTRL+C to exit.")
            await stop.wait()  # Run until stop signal received
        except asyncio.CancelledError:
            # NOTE(review): cancellation is swallowed, so a caller awaiting run()
            # sees a normal return instead of CancelledError.
            logger.warning("Operation cancelled.")
        except Exception as e:
            logger.error(f"An error occurred: {e}", exc_info=True)
        finally:
            await self.shutdown()

    async def shutdown(self) -> None:
        """Close the channel and connection; safe to call more than once."""
        logger.info("Shutting down RabbitMQ handler...")
        # Guarded because shutdown can run twice (STOP message, then run()'s finally).
        if self.channel and not self.channel.is_closed:
            await self.channel.close()
        if self.connection and not self.connection.is_closed:
            await self.connection.close()
        logger.info("RabbitMQ handler shut down successfully.")