feat: align async queue manager

This commit is contained in:
Jonathan Kössler 2024-07-12 15:14:13 +02:00
parent 9c28498d8a
commit 02665a5ef8
4 changed files with 19 additions and 468 deletions

View File

@ -5,7 +5,7 @@ from fastapi import FastAPI
from kn_utils.logging import logger
from pyinfra.config.loader import get_pyinfra_validators, validate_settings
from pyinfra.queue.async_tenants_v2 import RabbitMQConfig, RabbitMQHandler
from pyinfra.queue.async_manager import AsyncQueueManager, RabbitMQConfig
from pyinfra.queue.callback import Callback
from pyinfra.queue.manager import QueueManager
from pyinfra.utils.opentelemetry import instrument_app, instrument_pika, setup_trace
@ -27,11 +27,13 @@ def get_rabbitmq_config(settings: Dynaconf):
password=settings.rabbitmq.password,
heartbeat=settings.rabbitmq.heartbeat,
input_queue_prefix=settings.rabbitmq.service_request_queue_prefix,
tenant_event_queue_suffix=settings.rabbitmq.tenant_event_queue_suffix,
tenant_exchange_name=settings.rabbitmq.service_response_queue_prefix,
service_request_exchange_name=settings.rabbitmq.service_request_exchange_name,
service_response_exchange_name=settings.rabbitmq.service_response_exchange_name,
service_dead_letter_queue_name=settings.rabbitmq.service_dlq_name,
queue_expiration_time=settings.rabbitmq.queue_expiration_time,
pod_name=settings.kubernetes.pod_name,
)
@ -66,7 +68,7 @@ def start_standard_queue_consumer(
if settings.concurrency.enabled:
config = get_rabbitmq_config(settings)
manager = RabbitMQHandler(
manager = AsyncQueueManager(
config=config, tenant_service_url=settings.storage.tenant_server.endpoint, message_processor=callback
)
else:
@ -77,7 +79,7 @@ def start_standard_queue_consumer(
webserver_thread = create_webserver_thread_from_settings(app, settings)
webserver_thread.start()
if isinstance(manager, RabbitMQHandler):
if isinstance(manager, AsyncQueueManager):
asyncio.run(manager.run())
elif isinstance(manager, QueueManager):
manager.start_consuming(callback)

View File

@ -23,11 +23,13 @@ class RabbitMQConfig:
password: str
heartbeat: int
input_queue_prefix: str
tenant_event_queue_suffix: str
tenant_exchange_name: str
service_request_exchange_name: str
service_response_exchange_name: str
service_dead_letter_queue_name: str
queue_expiration_time: int
pod_name: str
connection_params: Dict[str, object] = field(init=False)
@ -41,7 +43,7 @@ class RabbitMQConfig:
}
class RabbitMQHandler:
class AsyncQueueManager:
def __init__(
self,
config: RabbitMQConfig,
@ -80,9 +82,8 @@ class RabbitMQHandler:
)
async def setup_tenant_queue(self) -> None:
# TODO: Add k8s pod_name to tenant queue name - add DLQ?
queue = await self.channel.declare_queue(
"tenant_queue",
f"{self.config.pod_name}_{self.config.tenant_event_queue_suffix}",
durable=True,
arguments={
"x-dead-letter-exchange": "",
@ -190,13 +191,14 @@ class RabbitMQHandler:
try:
async with aiohttp.ClientSession() as session:
async with session.get(self.tenant_service_url) as response:
# TODO: dont know if we should check for 200, could also be 2xx
# maybe handle bad requests with response.raise_for_status()
if response.status == 200:
response.raise_for_status()
if response.headers["content-type"].lower() == "application/json":
data = await response.json()
return {tenant["tenantId"] for tenant in data}
else:
logger.error(f"Failed to fetch active tenants. Status: {response.status}")
logger.error(
f"Failed to fetch active tenants. Content type is not JSON: {response.headers['content-type'].lower()}"
)
return set()
except aiohttp.ClientError as e:
logger.error(f"Error fetching active tenants: {e}")
@ -227,7 +229,7 @@ class RabbitMQHandler:
logger.info("RabbitMQ handler is running. Press CTRL+C to exit.")
await stop.wait() # Run until stop signal received
except asyncio.CancelledError:
logger.info("Operation cancelled.")
logger.warning("Operation cancelled.")
except Exception as e:
logger.error(f"An error occurred: {e}", exc_info=True)
finally:
@ -240,5 +242,3 @@ class RabbitMQHandler:
if self.connection:
await self.connection.close()
logger.info("RabbitMQ handler shut down successfully.")
# TODO: purge_queues

View File

@ -1,435 +0,0 @@
import asyncio
import concurrent.futures
import datetime
import json
import time
import uuid
from typing import Callable, Union
import aio_pika
import aiormq
import requests
from aio_pika import DeliveryMode, Message
from aio_pika.abc import AbstractIncomingMessage
from dynaconf import Dynaconf
from kn_utils.logging import logger
from pyinfra.config.loader import (
load_settings,
local_pyinfra_root_path,
validate_settings,
)
from pyinfra.config.validators import queue_manager_validators
MessageProcessor = Callable[[dict], dict]
class AsyncQueueManager:
def __init__(self, settings: Dynaconf, message_processor: Callable = None) -> None:
validate_settings(settings, queue_manager_validators)
self.message_processor = message_processor
self.connection_params = self.get_connection_params(settings)
self.connection = None
self.channel = None
self.active_tenants = self.get_initial_tenant_ids(tenant_endpoint_url=settings.storage.tenant_server.endpoint)
self.consumer_tasks = {}
self.connection_sleep = settings.rabbitmq.connection_sleep
self.queue_expiration_time = settings.rabbitmq.queue_expiration_time
self.tenant_created_queue_name = self.get_tenant_created_queue_name(settings)
self.tenant_deleted_queue_name = self.get_tenant_deleted_queue_name(settings)
self.tenant_events_dlq_name = self.get_tenant_events_dlq_name(settings)
self.tenant_exchange_name = settings.rabbitmq.tenant_exchange_name
self.service_request_exchange_name = settings.rabbitmq.service_request_exchange_name
self.service_response_exchange_name = settings.rabbitmq.service_response_exchange_name
self.service_request_queue_prefix = settings.rabbitmq.service_request_queue_prefix
self.service_response_queue_prefix = settings.rabbitmq.service_response_queue_prefix
self.service_dlq_name = settings.rabbitmq.service_dlq_name
@staticmethod
def get_connection_params(settings: Dynaconf):
return {
"host": settings.rabbitmq.host,
"port": settings.rabbitmq.port,
"login": settings.rabbitmq.username,
"password": settings.rabbitmq.password,
"client_properties": {"heartbeat": settings.rabbitmq.heartbeat},
}
def get_initial_tenant_ids(self, tenant_endpoint_url: str) -> set:
response = requests.get(tenant_endpoint_url, timeout=10)
response.raise_for_status() # Raise an HTTPError for bad responses
if response.headers["content-type"].lower() == "application/json":
tenants = {tenant["tenantId"] for tenant in response.json()}
return tenants
return set()
def get_tenant_created_queue_name(self, settings: Dynaconf) -> str:
return self.get_queue_name_with_suffix(
suffix=settings.rabbitmq.tenant_created_event_queue_suffix, pod_name=settings.kubernetes.pod_name
)
def get_tenant_deleted_queue_name(self, settings: Dynaconf) -> str:
return self.get_queue_name_with_suffix(
suffix=settings.rabbitmq.tenant_deleted_event_queue_suffix, pod_name=settings.kubernetes.pod_name
)
def get_tenant_events_dlq_name(self, settings: Dynaconf) -> str:
return self.get_queue_name_with_suffix(
suffix=settings.rabbitmq.tenant_event_dlq_suffix, pod_name=settings.kubernetes.pod_name
)
def get_queue_name_with_suffix(self, suffix: str, pod_name: str) -> str:
if not self.use_default_queue_name() and pod_name:
return f"{pod_name}{suffix}"
return self.get_default_queue_name()
def use_default_queue_name(self) -> bool:
return False
def get_default_queue_name(self):
raise NotImplementedError("Queue name method not implemented")
async def is_ready(self) -> bool:
await self.connect()
return self.channel.is_open
#### ASYNC STUFF
async def purge_queues(self) -> None:
await self.establish_connection()
try:
for tenant_id in self.active_tenants:
service_request_queue = await self.channel.get_queue(f"{self.service_request_queue_prefix}_{tenant_id}")
await service_request_queue.purge()
service_response_queue = await self.channel.get_queue(
f"{self.service_response_queue_prefix}_{tenant_id}"
)
await service_response_queue.purge()
logger.info("Queues purged.")
except aio_pika.exceptions.ChannelInvalidStateError:
pass
async def connect(self):
self.connection = await aio_pika.connect_robust(**self.connection_params)
self.channel = await self.connection.channel()
logger.info("Connection established.")
async def establish_connection(self):
await self.connect()
await self.initialize_queues()
logger.info("Queues initialized.")
# await self.start_processing()
async def start_processing(self):
await self.establish_connection()
tenant_events = asyncio.create_task(self.handle_tenant_events())
service_events = asyncio.create_task(self.start_consumers())
await asyncio.gather(tenant_events, service_events)
async def initialize_queues(self):
await self.channel.set_qos(prefetch_count=1)
service_request_exchange = await self.channel.declare_exchange(
name=self.service_request_exchange_name, type=aio_pika.ExchangeType.DIRECT, durable=True
)
service_response_exchange = await self.channel.declare_exchange(
name=self.service_response_exchange_name, type=aio_pika.ExchangeType.DIRECT, durable=True
)
for tenant_id in self.active_tenants:
request_queue_name = f"{self.service_request_queue_prefix}_{tenant_id}"
request_queue = await self.channel.declare_queue(
name=request_queue_name,
durable=True,
arguments={
"x-dead-letter-exchange": "",
"x-dead-letter-routing-key": self.service_dlq_name,
"x-expires": self.queue_expiration_time, # TODO: check if necessary
"x-max-priority": 2,
},
)
await request_queue.bind(exchange=service_request_exchange, routing_key=tenant_id)
response_queue_name = f"{self.service_response_queue_prefix}_{tenant_id}"
response_queue = await self.channel.declare_queue(
name=response_queue_name,
durable=True,
arguments={
"x-dead-letter-exchange": "",
"x-dead-letter-routing-key": self.service_dlq_name,
"x-expires": self.queue_expiration_time, # TODO: check if necessary
},
)
await response_queue.bind(exchange=service_response_exchange, routing_key=tenant_id)
async def handle_tenant_events(self):
# Declare the topic exchange for tenant events
exchange = await self.channel.declare_exchange(
self.tenant_exchange_name, aio_pika.ExchangeType.TOPIC, durable=True
)
# Declare a queue for receiving tenant events
queue = await self.channel.declare_queue(
"tenant_events_queue",
arguments={
"x-dead-letter-exchange": "",
"x-dead-letter-routing-key": self.tenant_events_dlq_name,
},
durable=True,
)
await queue.bind(exchange, routing_key="tenant.*")
async with queue.iterator() as queue_iter:
async for message in queue_iter:
async with message.process(reject_on_redelivered=True):
routing_key = message.routing_key
message_body = json.loads(message.body.decode())
tenant_id = message_body["tenantId"]
if routing_key == "tenant.created":
# Handle tenant creation
await self.handle_tenant_created(tenant_id)
elif routing_key == "tenant.deleted":
# Handle tenant deletion
await self.handle_tenant_deleted(tenant_id)
else:
message.nack()
continue
message.ack()
await self.restart_consumers()
async def handle_tenant_created(self, tenant_id):
# Handle creation of input and output queues for the new tenant
await self.create_tenant_queues(tenant_id)
await self.restart_consumers()
async def handle_tenant_deleted(self, tenant_id):
# Handle deletion of input and output queues for the tenant
await self.delete_tenant_queues(tenant_id)
await self.restart_consumers()
async def create_tenant_queues(self, tenant_id):
# Implement queue creation logic for the tenant
queue_name = f"{self.service_request_queue_prefix}_{tenant_id}"
queue = await self.channel.declare_queue(
name=queue_name,
durable=True,
arguments={
"x-dead-letter-exchange": "",
"x-dead-letter-routing-key": self.service_dlq_name,
"x-expires": self.queue_expiration_time, # TODO: check if necessary
},
)
exchange = await self.channel.get_exchange(self.service_request_exchange_name)
await queue.bind(exchange=exchange, routing_key=tenant_id)
self.active_tenants.add(tenant_id)
logger.info(f"Created queue for tenant {tenant_id}")
async def delete_tenant_queues(self, tenant_id):
queue_name = f"{self.service_request_queue_prefix}_{tenant_id}"
queue = await self.channel.get_queue(queue_name)
exchange = await self.channel.get_exchange(self.service_request_exchange_name)
await queue.unbind(exchange=exchange, routing_key=tenant_id)
await self.channel.queue_delete(queue_name)
self.active_tenants.discard(tenant_id)
async def consume_from_request_queue(self, tenant_id):
queue_name = f"{self.service_request_queue_prefix}_{tenant_id}"
queue = await self.channel.get_queue(queue_name)
async with queue.iterator() as queue_iter:
async for message in queue_iter:
async with message.process():
on_message_callback = await self._make_on_message_callback(self.message_processor, tenant_id)
await on_message_callback(message)
async def publish_to_service_response_queue(self, tenant_id, result):
service_response_exchange = await self.channel.get_exchange(self.service_response_exchange_name)
await service_response_exchange.publish(
Message(
body=json.dumps(result).encode(),
delivery_mode=DeliveryMode.NOT_PERSISTENT,
timestamp=datetime.datetime.now(),
message_id=str(uuid.uuid4()),
),
routing_key=tenant_id,
)
async def restart_consumers(self):
# Stop current consumers and start new ones for active tenants
await self.stop_consumers()
await self.start_consumers()
async def start_consumers(self):
# Start consuming messages from input queues for active tenants
for tenant_id in self.active_tenants:
if tenant_id not in self.consumer_tasks:
self.consumer_tasks[tenant_id] = asyncio.create_task(self.consume_from_request_queue(tenant_id))
consumer_tasks = [self.consume_from_request_queue(tenant) for tenant in self.active_tenants]
await asyncio.gather(*consumer_tasks)
async def stop_consumers(self):
for task in self.consumer_tasks.values():
task.cancel()
try:
await task
except asyncio.CancelledError:
pass
self.consumer_tasks.clear()
async def main_loop(self):
await self.establish_connection()
async def shutdown(self):
# Implement cleanup logic
await self.stop_consumers()
if self.connection:
await self.connection.close()
async def _make_on_message_callback(self, message_processor: MessageProcessor, tenant_id: str) -> Callable:
async def process_message_body_and_await_result(unpacked_message_body):
# Processing the message in a separate thread is necessary for the main thread pika client to be able to
# process data events (e.g. heartbeats) while the message is being processed.
# with concurrent.futures.ThreadPoolExecutor(max_workers=1) as thread_pool_executor:
# logger.info("Processing payload in separate thread.")
# future = thread_pool_executor.submit(message_processor, unpacked_message_body)
# return future.result()
return {"result": "lovely"}
async def on_message_callback(message: AbstractIncomingMessage):
logger.info(f"Received message from queue with delivery_tag {message.delivery_tag}.")
if message.redelivered:
logger.warning(f"Declining message with {message.delivery_tag=} due to it being redelivered.")
await message.nack(requeue=False)
return
if message.body.decode("utf-8") == "STOP":
logger.info("Received stop signal, stopping consuming...")
await message.ack()
await self.stop_consumers()
return
try:
filtered_message_headers = (
{k: v for k, v in message.properties.headers.items() if k.lower().startswith("x-")}
if message.properties.headers
else {}
)
logger.debug(f"Processing message with {filtered_message_headers=}.")
result: dict = await (
process_message_body_and_await_result({**json.loads(message.body), **filtered_message_headers})
or {}
)
await self.publish_to_service_response_queue(tenant_id, result)
logger.info(f"Published result to queue {tenant_id}.")
await message.ack()
logger.debug(f"Message with {message.delivery_tag=} acknowledged.")
except FileNotFoundError as e:
logger.warning(f"{e}, declining message with {message.delivery_tag=}.")
await message.nack(requeue=False)
except Exception as e:
logger.warning(f"Failed to process message with {message.delivery_tag=}, declining...", exc_info=True)
logger.warning(e)
await message.nack(requeue=False)
raise
return on_message_callback
async def publish_message_to_input_queue(self, tenant_id: str, message: Union[str, bytes, dict]) -> None:
if isinstance(message, str):
message = message.encode("utf-8")
elif isinstance(message, dict):
message = json.dumps(message).encode("utf-8")
await self.establish_connection()
service_request_exchange = await self.channel.get_exchange(self.service_request_exchange_name)
await service_request_exchange.publish(
message=Message(
body=message,
delivery_mode=DeliveryMode.NOT_PERSISTENT,
timestamp=datetime.datetime.now(),
message_id=str(uuid.uuid4()),
),
routing_key=tenant_id,
)
logger.info(f"Published message to queue {tenant_id}.")
async def publish_message_to_tenant_created_queue(self, message: Union[str, bytes, dict]) -> None:
if isinstance(message, str):
message = message.encode("utf-8")
elif isinstance(message, dict):
message = json.dumps(message).encode("utf-8")
await self.establish_connection()
service_request_exchange = await self.channel.get_exchange(self.tenant_exchange_name)
await service_request_exchange.publish(
message=Message(
body=message,
delivery_mode=DeliveryMode.NOT_PERSISTENT,
timestamp=datetime.datetime.now(),
message_id=str(uuid.uuid4()),
),
routing_key="tenant.created",
)
logger.info(f"Published message to queue {self.tenant_created_queue_name}.")
async def publish_message_to_tenant_deleted_queue(self, message: Union[str, bytes, dict]) -> None:
if isinstance(message, str):
message = message.encode("utf-8")
elif isinstance(message, dict):
message = json.dumps(message).encode("utf-8")
await self.establish_connection()
service_request_exchange = await self.channel.get_exchange(self.tenant_exchange_name)
await service_request_exchange.publish(
message=Message(
body=message,
delivery_mode=DeliveryMode.NOT_PERSISTENT,
timestamp=datetime.datetime.now(),
message_id=str(uuid.uuid4()),
),
routing_key="tenant.delete",
)
logger.info(f"Published message to queue {self.tenant_deleted_queue_name}.")
async def main() -> None:
import time
settings = load_settings(local_pyinfra_root_path / "config/")
callback = ""
manager = AsyncQueueManager(settings=settings, message_processor=callback)
await manager.main_loop()
while True:
time.sleep(100)
print("keep idling")
if __name__ == "__main__":
asyncio.run(main())

View File

@ -9,7 +9,8 @@ from aio_pika.abc import AbstractIncomingMessage
from kn_utils.logging import logger
from pyinfra.config.loader import load_settings, local_pyinfra_root_path
from pyinfra.queue.async_tenants_v2 import RabbitMQConfig, RabbitMQHandler
from pyinfra.examples import get_rabbitmq_config
from pyinfra.queue.async_manager import AsyncQueueManager
from pyinfra.storage.storages.s3 import S3Storage, get_s3_storage_from_settings
settings = load_settings(local_pyinfra_root_path / "config/")
@ -88,29 +89,12 @@ def upload_json_and_make_message_body(tenant_id: str):
async def test_rabbitmq_handler() -> None:
tenant_service_url = settings.storage.tenant_server.endpoint
config = RabbitMQConfig(
host=settings.rabbitmq.host,
port=settings.rabbitmq.port,
username=settings.rabbitmq.username,
password=settings.rabbitmq.password,
heartbeat=settings.rabbitmq.heartbeat,
input_queue_prefix=settings.rabbitmq.service_request_queue_prefix,
tenant_exchange_name=settings.rabbitmq.service_response_queue_prefix,
service_request_exchange_name=settings.rabbitmq.service_request_exchange_name,
service_response_exchange_name=settings.rabbitmq.service_response_exchange_name,
service_dead_letter_queue_name=settings.rabbitmq.service_dlq_name,
queue_expiration_time=settings.rabbitmq.queue_expiration_time,
)
config = get_rabbitmq_config(settings)
handler = RabbitMQHandler(config, tenant_service_url, dummy_message_processor)
handler = AsyncQueueManager(config, tenant_service_url, dummy_message_processor)
await handler.connect()
await handler.setup_exchanges()
# await handler.initialize_tenant_queues()
# await handler.setup_tenant_queue()
# for queue in handler.tenant_queues.values():
# await queue.purge()
tenant_id = "test_tenant"