chore: clean up + improve robustness
This commit is contained in:
parent
7559118822
commit
f9aec74d55
@ -4,69 +4,92 @@ from operator import itemgetter
|
||||
from typing import Any, Callable, Dict, Set
|
||||
import aiohttp
|
||||
from aio_pika import connect_robust, ExchangeType, Message, IncomingMessage
|
||||
from aio_pika.abc import AbstractIncomingMessage
|
||||
from aio_pika.abc import AbstractIncomingMessage, AbstractChannel, AbstractConnection, AbstractExchange
|
||||
import json
|
||||
from logging import getLogger
|
||||
from kn_utils.logging import logger
|
||||
from pyinfra.config.loader import load_settings, local_pyinfra_root_path
|
||||
from pyinfra.storage.storages.s3 import get_s3_storage_from_settings
|
||||
|
||||
logger = getLogger(__name__)
|
||||
logger.setLevel("DEBUG")
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
settings = load_settings(local_pyinfra_root_path / "config/")
|
||||
|
||||
|
||||
@dataclass
|
||||
class RabbitMQConfig:
|
||||
host: str
|
||||
port: int
|
||||
username: str
|
||||
password: str
|
||||
heartbeat: int
|
||||
input_queue_prefix: str
|
||||
tenant_exchange_name: str
|
||||
service_request_exchange_name: str
|
||||
service_response_exchange_name: str
|
||||
service_dead_letter_queue_name: str
|
||||
queue_expiration_time: int
|
||||
|
||||
connection_params: Dict[str, object] = field(init=False)
|
||||
|
||||
def __post_init__(self):
|
||||
self.connection_params = {
|
||||
"host": self.host,
|
||||
"port": self.port,
|
||||
"login": self.username,
|
||||
"password": self.password,
|
||||
"client_properties": {"heartbeat": self.heartbeat},
|
||||
}
|
||||
|
||||
|
||||
class RabbitMQHandler:
|
||||
def __init__(self, connection_params, tenant_service_url, message_processor):
|
||||
self.connection_params = connection_params
|
||||
def __init__(
|
||||
self,
|
||||
config: RabbitMQConfig,
|
||||
tenant_service_url: str,
|
||||
message_processor: Callable[[Dict[str, Any]], Dict[str, Any]],
|
||||
):
|
||||
self.config = config
|
||||
self.tenant_service_url = tenant_service_url
|
||||
# TODO: remove hardcoded values
|
||||
self.input_queue_prefix = "service_request_queue"
|
||||
self.tenant_exchange_name = "tenants-exchange"
|
||||
self.service_request_exchange_name = "service_request_exchange" # INPUT
|
||||
self.service_response_exchange_name = "service_response_exchange" # OUTPUT
|
||||
self.service_dead_letter_queue_name = "service_dlq"
|
||||
self.queue_expiration_time = 300000
|
||||
self.message_processor = message_processor
|
||||
self.connection = None
|
||||
self.channel = None
|
||||
self.tenant_exchange = None
|
||||
self.input_exchange = None
|
||||
self.output_exchange = None
|
||||
self.tenant_queues = {}
|
||||
|
||||
async def connect(self):
|
||||
self.connection = await connect_robust(**self.connection_params)
|
||||
self.connection: AbstractConnection | None = None
|
||||
self.channel: AbstractChannel | None = None
|
||||
self.tenant_exchange: AbstractExchange | None = None
|
||||
self.input_exchange: AbstractExchange | None = None
|
||||
self.output_exchange: AbstractExchange | None = None
|
||||
self.tenant_queues: Dict[str, AbstractChannel] = {}
|
||||
|
||||
async def connect(self) -> None:
|
||||
self.connection = await connect_robust(**self.config.connection_params)
|
||||
self.channel = await self.connection.channel()
|
||||
await self.channel.set_qos(prefetch_count=1)
|
||||
|
||||
# Declare exchanges
|
||||
async def setup_exchanges(self):
|
||||
self.tenant_exchange = await self.channel.declare_exchange(
|
||||
self.tenant_exchange_name, ExchangeType.TOPIC, durable=True
|
||||
self.config.tenant_exchange_name, ExchangeType.TOPIC, durable=True
|
||||
)
|
||||
self.input_exchange = await self.channel.declare_exchange(
|
||||
self.service_request_exchange_name, ExchangeType.DIRECT, durable=True
|
||||
self.config.service_request_exchange_name, ExchangeType.DIRECT, durable=True
|
||||
)
|
||||
self.output_exchange = await self.channel.declare_exchange(
|
||||
self.service_response_exchange_name, ExchangeType.DIRECT, durable=True
|
||||
self.config.service_response_exchange_name, ExchangeType.DIRECT, durable=True
|
||||
)
|
||||
|
||||
async def setup_tenant_queue(self):
|
||||
async def setup_tenant_queue(self) -> None:
|
||||
queue = await self.channel.declare_queue(
|
||||
"tenant_queue",
|
||||
durable=True,
|
||||
arguments={
|
||||
"x-dead-letter-exchange": "",
|
||||
"x-dead-letter-routing-key": self.service_dead_letter_queue_name,
|
||||
"x-dead-letter-routing-key": self.config.service_dead_letter_queue_name,
|
||||
},
|
||||
)
|
||||
await queue.bind(self.tenant_exchange, routing_key="tenant.*")
|
||||
await queue.consume(self.process_tenant_message)
|
||||
|
||||
async def process_tenant_message(self, message: AbstractIncomingMessage):
|
||||
async def process_tenant_message(self, message: AbstractIncomingMessage) -> None:
|
||||
async with message.process():
|
||||
message_body = json.loads(message.body.decode())
|
||||
print(message_body)
|
||||
logger.debug(f"input message: {message_body}")
|
||||
logger.debug(f"Tenant message received: {message_body}")
|
||||
tenant_id = message_body["tenantId"]
|
||||
routing_key = message.routing_key
|
||||
|
||||
@ -75,72 +98,60 @@ class RabbitMQHandler:
|
||||
elif routing_key == "tenant.delete":
|
||||
await self.delete_tenant_queues(tenant_id)
|
||||
|
||||
async def create_tenant_queues(self, tenant_id):
|
||||
# Create and bind input queue
|
||||
queue_name = f"{self.input_queue_prefix}_{tenant_id}"
|
||||
print(f"queue declared: {queue_name}")
|
||||
async def create_tenant_queues(self, tenant_id: str) -> None:
|
||||
queue_name = f"{self.config.input_queue_prefix}_{tenant_id}"
|
||||
logger.info(f"Declaring queue: {queue_name}")
|
||||
input_queue = await self.channel.declare_queue(
|
||||
queue_name,
|
||||
durable=True,
|
||||
arguments={
|
||||
"x-dead-letter-exchange": "",
|
||||
"x-dead-letter-routing-key": self.service_dead_letter_queue_name,
|
||||
"x-expires": self.queue_expiration_time,
|
||||
"x-dead-letter-routing-key": self.config.service_dead_letter_queue_name,
|
||||
"x-expires": self.config.queue_expiration_time,
|
||||
"x-max-priority": 2,
|
||||
},
|
||||
)
|
||||
await input_queue.bind(self.input_exchange, routing_key=tenant_id)
|
||||
await input_queue.consume(lambda msg: self.process_input_message(msg, self.message_processor))
|
||||
await input_queue.consume(self.process_input_message)
|
||||
|
||||
# Store queues for later use
|
||||
self.tenant_queues[tenant_id] = input_queue
|
||||
print(f"Created queues for tenant {tenant_id}")
|
||||
logger.info(f"Created queues for tenant {tenant_id}")
|
||||
|
||||
async def delete_tenant_queues(self, tenant_id):
|
||||
async def delete_tenant_queues(self, tenant_id: str) -> None:
|
||||
if tenant_id in self.tenant_queues:
|
||||
input_queue = self.tenant_queues[tenant_id]
|
||||
await input_queue.delete()
|
||||
del self.tenant_queues[tenant_id]
|
||||
print(f"Deleted queues for tenant {tenant_id}")
|
||||
logger.info(f"Deleted queues for tenant {tenant_id}")
|
||||
|
||||
async def process_input_message(
|
||||
self, message: IncomingMessage, message_processor: Callable[[Dict[str, Any]], Dict[str, Any]]
|
||||
) -> None:
|
||||
async def process_input_message(self, message: IncomingMessage) -> None:
|
||||
async with message.process():
|
||||
try:
|
||||
tenant_id = message.routing_key
|
||||
message_body = json.loads(message.body.decode("utf-8"))
|
||||
|
||||
# Extract headers
|
||||
filtered_message_headers = (
|
||||
{k: v for k, v in message.headers.items() if k.lower().startswith("x-")} if message.headers else {}
|
||||
)
|
||||
|
||||
logger.debug(f"Processing message with {filtered_message_headers=}.")
|
||||
|
||||
# Process the message
|
||||
message_body.update(filtered_message_headers)
|
||||
|
||||
# Run the message processor in a separate thread to avoid blocking
|
||||
loop = asyncio.get_running_loop()
|
||||
result = await loop.run_in_executor(None, message_processor, message_body)
|
||||
result = await self.message_processor(message_body)
|
||||
|
||||
if result:
|
||||
# Publish the result to the output exchange
|
||||
self.publish_to_output_exchange(tenant_id, result, filtered_message_headers)
|
||||
await self.publish_to_output_exchange(tenant_id, result, filtered_message_headers)
|
||||
|
||||
except json.JSONDecodeError:
|
||||
logger.error(f"Invalid JSON in input message: {message.body}")
|
||||
# Message will be nacked and sent to dead-letter queue
|
||||
except FileNotFoundError as e:
|
||||
logger.warning(f"{e}, declining message.")
|
||||
# Message will be nacked and sent to dead-letter queue
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing input message: {e}", exc_info=True)
|
||||
# Message will be nacked and sent to dead-letter queue
|
||||
raise
|
||||
|
||||
async def publish_to_output_exchange(self, tenant_id: str, result: Dict[str, Any], headers: Dict[str, Any]):
|
||||
async def publish_to_output_exchange(self, tenant_id: str, result: Dict[str, Any], headers: Dict[str, Any]) -> None:
|
||||
await self.output_exchange.publish(
|
||||
Message(body=json.dumps(result).encode(), headers=headers),
|
||||
routing_key=tenant_id,
|
||||
@ -161,63 +172,26 @@ class RabbitMQHandler:
|
||||
logger.error(f"Error fetching active tenants: {e}")
|
||||
return set()
|
||||
|
||||
async def initialize_tenant_queues(self):
|
||||
async def initialize_tenant_queues(self) -> None:
|
||||
active_tenants = await self.fetch_active_tenants()
|
||||
for tenant_id in active_tenants:
|
||||
await self.create_tenant_queues(tenant_id)
|
||||
|
||||
async def run(self):
|
||||
await self.connect()
|
||||
await self.initialize_tenant_queues()
|
||||
await self.setup_tenant_queue()
|
||||
print("RabbitMQ handler is running. Press CTRL+C to exit.")
|
||||
async def run(self) -> None:
|
||||
try:
|
||||
await self.connect()
|
||||
await self.setup_exchanges()
|
||||
await self.initialize_tenant_queues()
|
||||
await self.setup_tenant_queue()
|
||||
logger.info("RabbitMQ handler is running. Press CTRL+C to exit.")
|
||||
await asyncio.Future() # Run forever
|
||||
except asyncio.CancelledError:
|
||||
logger.info("Shutting down RabbitMQ handler...")
|
||||
except Exception as e:
|
||||
logger.error(f"An error occurred: {e}", exc_info=True)
|
||||
finally:
|
||||
await self.connection.close()
|
||||
|
||||
|
||||
async def main():
|
||||
connection_params = {
|
||||
"host": settings.rabbitmq.host,
|
||||
"port": settings.rabbitmq.port,
|
||||
"login": settings.rabbitmq.username,
|
||||
"password": settings.rabbitmq.password,
|
||||
"client_properties": {"heartbeat": settings.rabbitmq.heartbeat},
|
||||
}
|
||||
tenant_service_url = "http://localhost:8080/internal/tenants"
|
||||
handler = RabbitMQHandler(connection_params, tenant_service_url, dummy_message_processor)
|
||||
await handler.run()
|
||||
|
||||
|
||||
########################################################################
|
||||
def upload_json_and_make_message_body(tenant_id: str) -> Dict[str, Any]:
|
||||
dossier_id, file_id, suffix = "dossier", "file", "json.gz"
|
||||
content = {
|
||||
"numberOfPages": 7,
|
||||
"sectionTexts": "data",
|
||||
}
|
||||
|
||||
object_name = f"{tenant_id}/{dossier_id}/{file_id}.{suffix}"
|
||||
data = gzip.compress(json.dumps(content).encode("utf-8"))
|
||||
|
||||
storage = get_s3_storage_from_settings(settings)
|
||||
if not storage.has_bucket():
|
||||
storage.make_bucket()
|
||||
storage.put_object(object_name, data)
|
||||
|
||||
message_body = {
|
||||
"tenantId": tenant_id,
|
||||
"dossierId": dossier_id,
|
||||
"fileId": file_id,
|
||||
"targetFileExtension": suffix,
|
||||
"responseFileExtension": f"result.{suffix}",
|
||||
}
|
||||
return message_body
|
||||
|
||||
|
||||
def tenant_event_message(tenant_id: str) -> Dict[str, str]:
|
||||
return {"tenantId": tenant_id}
|
||||
if self.connection:
|
||||
await self.connection.close()
|
||||
|
||||
|
||||
async def dummy_message_processor(message: Dict[str, Any]) -> Dict[str, Any]:
|
||||
@ -228,7 +202,6 @@ async def dummy_message_processor(message: Dict[str, Any]) -> Dict[str, Any]:
|
||||
tenant_id, dossier_id, file_id = itemgetter("tenantId", "dossierId", "fileId")(message)
|
||||
suffix = message["responseFileExtension"]
|
||||
|
||||
# Simulate processing by modifying the original content
|
||||
object_name = f"{tenant_id}/{dossier_id}/{file_id}.{message['targetFileExtension']}"
|
||||
original_content = json.loads(gzip.decompress(storage.get_object(object_name)))
|
||||
processed_content = {
|
||||
@ -236,7 +209,6 @@ async def dummy_message_processor(message: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"processedSectionTexts": f"Processed: {original_content['sectionTexts']}",
|
||||
}
|
||||
|
||||
# Save processed content
|
||||
processed_object_name = f"{tenant_id}/{dossier_id}/{file_id}.{suffix}"
|
||||
processed_data = gzip.compress(json.dumps(processed_content).encode("utf-8"))
|
||||
storage.put_object(processed_object_name, processed_data)
|
||||
@ -249,40 +221,59 @@ async def dummy_message_processor(message: Dict[str, Any]) -> Dict[str, Any]:
|
||||
return processed_message
|
||||
|
||||
|
||||
async def test_rabbitmq_handler():
|
||||
connection_params = {
|
||||
"host": settings.rabbitmq.host,
|
||||
"port": settings.rabbitmq.port,
|
||||
"login": settings.rabbitmq.username,
|
||||
"password": settings.rabbitmq.password,
|
||||
"client_properties": {"heartbeat": settings.rabbitmq.heartbeat},
|
||||
}
|
||||
tenant_service_url = "http://localhost:8080/internal/tenants"
|
||||
handler = RabbitMQHandler(connection_params, tenant_service_url, dummy_message_processor)
|
||||
async def test_rabbitmq_handler() -> None:
|
||||
tenant_service_url = settings.storage.tenant_server.endpoint
|
||||
|
||||
config = RabbitMQConfig(
|
||||
host=settings.rabbitmq.host,
|
||||
port=settings.rabbitmq.port,
|
||||
username=settings.rabbitmq.username,
|
||||
password=settings.rabbitmq.password,
|
||||
heartbeat=settings.rabbitmq.heartbeat,
|
||||
input_queue_prefix=settings.rabbitmq.service_request_queue_prefix,
|
||||
tenant_exchange_name=settings.rabbitmq.service_response_queue_prefix,
|
||||
service_request_exchange_name=settings.rabbitmq.service_request_exchange_name,
|
||||
service_response_exchange_name=settings.rabbitmq.service_response_exchange_name,
|
||||
service_dead_letter_queue_name=settings.rabbitmq.service_dlq_name,
|
||||
queue_expiration_time=settings.rabbitmq.queue_expiration_time,
|
||||
)
|
||||
|
||||
handler = RabbitMQHandler(config, tenant_service_url, dummy_message_processor)
|
||||
|
||||
await handler.connect()
|
||||
await handler.setup_exchanges()
|
||||
await handler.initialize_tenant_queues()
|
||||
await handler.setup_tenant_queue()
|
||||
|
||||
# Test tenant creation
|
||||
tenant_id = "test_tenant"
|
||||
create_message = tenant_event_message(tenant_id)
|
||||
|
||||
# Test tenant creation
|
||||
create_message = {"tenantId": tenant_id}
|
||||
await handler.tenant_exchange.publish(
|
||||
Message(body=json.dumps(create_message).encode()), routing_key="tenant.create"
|
||||
)
|
||||
logger.info(f"Sent create tenant message for {tenant_id}")
|
||||
await asyncio.sleep(2) # Wait for queue creation
|
||||
|
||||
# Test service request
|
||||
service_request = upload_json_and_make_message_body(tenant_id)
|
||||
service_request = {
|
||||
"tenantId": tenant_id,
|
||||
"dossierId": "dossier",
|
||||
"fileId": "file",
|
||||
"targetFileExtension": "json.gz",
|
||||
"responseFileExtension": "result.json.gz",
|
||||
}
|
||||
await handler.input_exchange.publish(Message(body=json.dumps(service_request).encode()), routing_key=tenant_id)
|
||||
logger.info(f"Sent service request for {tenant_id}")
|
||||
await asyncio.sleep(5) # Wait for message processing
|
||||
|
||||
# Test tenant deletion
|
||||
delete_message = tenant_event_message(tenant_id)
|
||||
delete_message = {"tenantId": tenant_id}
|
||||
await handler.tenant_exchange.publish(
|
||||
Message(body=json.dumps(delete_message).encode()), routing_key="tenant.delete"
|
||||
)
|
||||
logger.info(f"Sent delete tenant message for {tenant_id}")
|
||||
await asyncio.sleep(2) # Wait for queue deletion
|
||||
|
||||
await handler.connection.close()
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user