feat: add async_v2
This commit is contained in:
parent a5162d5bf0
commit 8844df44ce
@@ -3,12 +3,14 @@ import asyncio
from dynaconf import Dynaconf
from fastapi import FastAPI
from kn_utils.logging import logger
from threading import Thread

# from threading import Thread
from pyinfra.config.loader import get_pyinfra_validators, validate_settings

# from pyinfra.queue.threaded_tenants import ServiceQueueManager, TenantQueueManager
from pyinfra.queue.async_tenants_v2 import RabbitMQConfig, RabbitMQHandler
from pyinfra.queue.callback import Callback
from pyinfra.queue.threaded_tenants import ServiceQueueManager, TenantQueueManager
from pyinfra.queue.async_tenants import AsyncQueueManager
from pyinfra.utils.opentelemetry import instrument_pika, setup_trace, instrument_app
from pyinfra.utils.opentelemetry import instrument_app, instrument_pika, setup_trace
from pyinfra.webserver.prometheus import (
    add_prometheus_endpoint,
    make_prometheus_processing_time_decorator_from_settings,
@@ -19,6 +21,22 @@ from pyinfra.webserver.utils import (
)


def get_rabbitmq_config(settings: Dynaconf):
    return RabbitMQConfig(
        host=settings.rabbitmq.host,
        port=settings.rabbitmq.port,
        username=settings.rabbitmq.username,
        password=settings.rabbitmq.password,
        heartbeat=settings.rabbitmq.heartbeat,
        input_queue_prefix=settings.rabbitmq.service_request_queue_prefix,
        tenant_exchange_name=settings.rabbitmq.service_response_queue_prefix,
        service_request_exchange_name=settings.rabbitmq.service_request_exchange_name,
        service_response_exchange_name=settings.rabbitmq.service_response_exchange_name,
        service_dead_letter_queue_name=settings.rabbitmq.service_dlq_name,
        queue_expiration_time=settings.rabbitmq.queue_expiration_time,
    )


def start_standard_queue_consumer(
    callback: Callback,
    settings: Dynaconf,
@@ -37,8 +55,8 @@ def start_standard_queue_consumer(

    app = app or FastAPI()

    tenant_manager = TenantQueueManager(settings)
    service_manager = ServiceQueueManager(settings)
    # tenant_manager = TenantQueueManager(settings)
    # service_manager = ServiceQueueManager(settings)

    if settings.metrics.prometheus.enabled:
        logger.info("Prometheus metrics enabled.")
@@ -52,20 +70,24 @@ def start_standard_queue_consumer(
        instrument_app(app)

    # manager = AsyncQueueManager(settings=settings, message_processor=callback)
    config = get_rabbitmq_config(settings)
    manager = RabbitMQHandler(
        config=config, tenant_service_url=settings.storage.tenant_server.endpoint, message_processor=callback
    )

    app = add_health_check_endpoint(app, service_manager.is_ready)
    # app = add_health_check_endpoint(app, manager.is_ready)
    # app = add_health_check_endpoint(app, service_manager.is_ready)
    app = add_health_check_endpoint(app, manager.is_ready)

    webserver_thread = create_webserver_thread_from_settings(app, settings)
    webserver_thread.start()

    thread_t = Thread(target=tenant_manager.start_consuming, daemon=True)
    thread_s = Thread(target=service_manager.start_sequential_basic_get, args=(callback,), daemon=True)
    # thread_t = Thread(target=tenant_manager.start_consuming, daemon=True)
    # thread_s = Thread(target=service_manager.start_sequential_basic_get, args=(callback,), daemon=True)

    thread_t.start()
    thread_s.start()
    # thread_t.start()
    # thread_s.start()

    thread_t.join()
    thread_s.join()

    # asyncio.run(manager.start_processing())
    # thread_t.join()
    # thread_s.join()

    asyncio.run(manager.run())
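(Not part of the commit: a rough usage sketch. Assuming the consumer module above is importable and that a plain dict-in/dict-out function is an acceptable stand-in for a pyinfra Callback, a service entry point wired through start_standard_queue_consumer might look like this; the function name and config path are illustrative assumptions, not names taken from the repository.)

import asyncio  # noqa: F401  (the consumer itself calls asyncio.run internally)

from dynaconf import Dynaconf

from pyinfra.config.loader import load_settings, local_pyinfra_root_path
# import path for start_standard_queue_consumer omitted here; it is the module edited in the first hunk above


def process_payload(message: dict) -> dict:
    # Hypothetical callback: receives the decoded message body and returns the result
    # that the RabbitMQHandler publishes to the service response exchange.
    return {"echo": message}


if __name__ == "__main__":
    settings: Dynaconf = load_settings(local_pyinfra_root_path / "config/")  # assumed config location
    start_standard_queue_consumer(callback=process_payload, settings=settings)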
@@ -1,28 +1,31 @@
import aiormq
import asyncio
import aio_pika
import concurrent.futures
import requests
import datetime
import json

from aio_pika import Message, DeliveryMode
from aio_pika.abc import AbstractIncomingMessage
from dynaconf import Dynaconf
import time
import uuid
from typing import Callable, Union

import aio_pika
import aiormq
import requests
from aio_pika import DeliveryMode, Message
from aio_pika.abc import AbstractIncomingMessage
from dynaconf import Dynaconf
from kn_utils.logging import logger

from pyinfra.config.loader import validate_settings
from pyinfra.config.loader import (
    load_settings,
    local_pyinfra_root_path,
    validate_settings,
)
from pyinfra.config.validators import queue_manager_validators
from pyinfra.config.loader import load_settings, local_pyinfra_root_path
from pyinfra.queue.callback import make_download_process_upload_callback


MessageProcessor = Callable[[dict], dict]


class AsyncQueueManager:


    def __init__(self, settings: Dynaconf, message_processor: Callable = None) -> None:
        validate_settings(settings, queue_manager_validators)

@@ -36,12 +39,12 @@ class AsyncQueueManager:

        self.connection_sleep = settings.rabbitmq.connection_sleep
        self.queue_expiration_time = settings.rabbitmq.queue_expiration_time


        self.tenant_created_queue_name = self.get_tenant_created_queue_name(settings)
        self.tenant_deleted_queue_name = self.get_tenant_deleted_queue_name(settings)
        self.tenant_events_dlq_name = self.get_tenant_events_dlq_name(settings)
        self.tenant_exchange_name = settings.rabbitmq.tenant_exchange_name


        self.service_request_exchange_name = settings.rabbitmq.service_request_exchange_name
        self.service_response_exchange_name = settings.rabbitmq.service_response_exchange_name

@@ -50,18 +53,16 @@ class AsyncQueueManager:

        self.service_dlq_name = settings.rabbitmq.service_dlq_name


    @staticmethod
    def get_connection_params(settings: Dynaconf):
        return {

            "host": settings.rabbitmq.host,
            "port": settings.rabbitmq.port,
            "login": settings.rabbitmq.username,
            "password":settings.rabbitmq.password,
            "client_properties": {"heartbeat": settings.rabbitmq.heartbeat}
            "password": settings.rabbitmq.password,
            "client_properties": {"heartbeat": settings.rabbitmq.heartbeat},
        }


    def get_initial_tenant_ids(self, tenant_endpoint_url: str) -> set:
        response = requests.get(tenant_endpoint_url, timeout=10)
        response.raise_for_status()  # Raise an HTTPError for bad responses
@@ -70,7 +71,7 @@ class AsyncQueueManager:

        tenants = {tenant["tenantId"] for tenant in response.json()}
        return tenants
        return set()


    def get_tenant_created_queue_name(self, settings: Dynaconf) -> str:
        return self.get_queue_name_with_suffix(
            suffix=settings.rabbitmq.tenant_created_event_queue_suffix, pod_name=settings.kubernetes.pod_name
@@ -96,7 +97,7 @@ class AsyncQueueManager:

    def get_default_queue_name(self):
        raise NotImplementedError("Queue name method not implemented")


    async def is_ready(self) -> bool:
        await self.connect()
        return self.channel.is_open
@@ -108,7 +109,9 @@ class AsyncQueueManager:
            for tenant_id in self.active_tenants:
                service_request_queue = await self.channel.get_queue(f"{self.service_request_queue_prefix}_{tenant_id}")
                await service_request_queue.purge()
                service_response_queue = await self.channel.get_queue(f"{self.service_response_queue_prefix}_{tenant_id}")
                service_response_queue = await self.channel.get_queue(
                    f"{self.service_response_queue_prefix}_{tenant_id}"
                )
                await service_response_queue.purge()
            logger.info("Queues purged.")
        except aio_pika.exceptions.ChannelInvalidStateError:
@@ -132,12 +135,15 @@ class AsyncQueueManager:

        await asyncio.gather(tenant_events, service_events)


    async def initialize_queues(self):
        await self.channel.set_qos(prefetch_count=1)

        service_request_exchange = await self.channel.declare_exchange(name=self.service_request_exchange_name, type=aio_pika.ExchangeType.DIRECT)
        service_response_exchange = await self.channel.declare_exchange(name=self.service_response_exchange_name, type=aio_pika.ExchangeType.DIRECT)
        service_request_exchange = await self.channel.declare_exchange(
            name=self.service_request_exchange_name, type=aio_pika.ExchangeType.DIRECT, durable=True
        )
        service_response_exchange = await self.channel.declare_exchange(
            name=self.service_response_exchange_name, type=aio_pika.ExchangeType.DIRECT, durable=True
        )

        for tenant_id in self.active_tenants:
            request_queue_name = f"{self.service_request_queue_prefix}_{tenant_id}"
@@ -164,26 +170,30 @@ class AsyncQueueManager:
                },
            )
            await response_queue.bind(exchange=service_response_exchange, routing_key=tenant_id)


    async def handle_tenant_events(self):
        # Declare the topic exchange for tenant events
        exchange = await self.channel.declare_exchange(
            self.tenant_exchange_name, aio_pika.ExchangeType.TOPIC
            self.tenant_exchange_name, aio_pika.ExchangeType.TOPIC, durable=True
        )
        # Declare a queue for receiving tenant events
        queue = await self.channel.declare_queue("tenant_events_queue", arguments={
        queue = await self.channel.declare_queue(
            "tenant_events_queue",
            arguments={
                "x-dead-letter-exchange": "",
                "x-dead-letter-routing-key": self.tenant_events_dlq_name,
            },
            durable=True,)

            durable=True,
        )

        await queue.bind(exchange, routing_key="tenant.*")

        async with queue.iterator() as queue_iter:
            async for message in queue_iter:
                async with message.process(reject_on_redelivered=True):
                    routing_key = message.routing_key
                    tenant_id = message.body.decode()
                    message_body = json.loads(message.body.decode())
                    tenant_id = message_body["tenantId"]
                    if routing_key == "tenant.created":
                        # Handle tenant creation
                        await self.handle_tenant_created(tenant_id)
@@ -222,7 +232,7 @@ class AsyncQueueManager:
        exchange = await self.channel.get_exchange(self.service_request_exchange_name)
        await queue.bind(exchange=exchange, routing_key=tenant_id)
        self.active_tenants.add(tenant_id)

        logger.info(f"Created queue for tenant {tenant_id}")

    async def delete_tenant_queues(self, tenant_id):
        queue_name = f"{self.service_request_queue_prefix}_{tenant_id}"
@@ -245,8 +255,15 @@ class AsyncQueueManager:
    async def publish_to_service_response_queue(self, tenant_id, result):
        service_response_exchange = await self.channel.get_exchange(self.service_response_exchange_name)

        await service_response_exchange.publish(aio_pika.Message(body=json.dumps(result).encode()),
                                                routing_key=tenant_id,)
        await service_response_exchange.publish(
            Message(
                body=json.dumps(result).encode(),
                delivery_mode=DeliveryMode.NOT_PERSISTENT,
                timestamp=datetime.datetime.now(),
                message_id=str(uuid.uuid4()),
            ),
            routing_key=tenant_id,
        )

    async def restart_consumers(self):
        # Stop current consumers and start new ones for active tenants
@@ -284,11 +301,13 @@ class AsyncQueueManager:
        async def process_message_body_and_await_result(unpacked_message_body):
            # Processing the message in a separate thread is necessary for the main thread pika client to be able to
            # process data events (e.g. heartbeats) while the message is being processed.
            with concurrent.futures.ThreadPoolExecutor(max_workers=1) as thread_pool_executor:
                logger.info("Processing payload in separate thread.")
                future = thread_pool_executor.submit(message_processor, unpacked_message_body)
            # with concurrent.futures.ThreadPoolExecutor(max_workers=1) as thread_pool_executor:
            #     logger.info("Processing payload in separate thread.")
            #     future = thread_pool_executor.submit(message_processor, unpacked_message_body)

            return future.result()
            # return future.result()

            return {"result": "lovely"}

        async def on_message_callback(message: AbstractIncomingMessage):
            logger.info(f"Received message from queue with delivery_tag {message.delivery_tag}.")
@@ -312,12 +331,13 @@ class AsyncQueueManager:
                )
                logger.debug(f"Processing message with {filtered_message_headers=}.")
                result: dict = await (
                    process_message_body_and_await_result({**json.loads(message.body), **filtered_message_headers}) or {}
                    process_message_body_and_await_result({**json.loads(message.body), **filtered_message_headers})
                    or {}
                )

                await self.publish_to_service_response_queue(tenant_id, result)
                logger.info(f"Published result to queue {tenant_id}.")


                await message.ack()
                logger.debug(f"Message with {message.delivery_tag=} acknowledged.")
            except FileNotFoundError as e:
@@ -330,9 +350,8 @@ class AsyncQueueManager:
                raise

        return on_message_callback

    async def publish_message_to_input_queue(
        self, tenant_id: str, message: Union[str, bytes, dict]) -> None:

    async def publish_message_to_input_queue(self, tenant_id: str, message: Union[str, bytes, dict]) -> None:
        if isinstance(message, str):
            message = message.encode("utf-8")
        elif isinstance(message, dict):
@@ -342,12 +361,19 @@ class AsyncQueueManager:

        service_request_exchange = await self.channel.get_exchange(self.service_request_exchange_name)

        await service_request_exchange.publish(message=Message(body=message, delivery_mode=DeliveryMode.NOT_PERSISTENT), routing_key=tenant_id)
        await service_request_exchange.publish(
            message=Message(
                body=message,
                delivery_mode=DeliveryMode.NOT_PERSISTENT,
                timestamp=datetime.datetime.now(),
                message_id=str(uuid.uuid4()),
            ),
            routing_key=tenant_id,
        )

        logger.info(f"Published message to queue {tenant_id}.")

    async def publish_message_to_tenant_created_queue(
        self, message: Union[str, bytes, dict]) -> None:
    async def publish_message_to_tenant_created_queue(self, message: Union[str, bytes, dict]) -> None:
        if isinstance(message, str):
            message = message.encode("utf-8")
        elif isinstance(message, dict):
@@ -356,12 +382,19 @@ class AsyncQueueManager:
        await self.establish_connection()
        service_request_exchange = await self.channel.get_exchange(self.tenant_exchange_name)

        await service_request_exchange.publish(message=Message(body=message, delivery_mode=DeliveryMode.NOT_PERSISTENT), routing_key="tenant.created")
        await service_request_exchange.publish(
            message=Message(
                body=message,
                delivery_mode=DeliveryMode.NOT_PERSISTENT,
                timestamp=datetime.datetime.now(),
                message_id=str(uuid.uuid4()),
            ),
            routing_key="tenant.created",
        )

        logger.info(f"Published message to queue {self.tenant_created_queue_name}.")

    async def publish_message_to_tenant_deleted_queue(
        self, message: Union[str, bytes, dict]) -> None:

    async def publish_message_to_tenant_deleted_queue(self, message: Union[str, bytes, dict]) -> None:
        if isinstance(message, str):
            message = message.encode("utf-8")
        elif isinstance(message, dict):
@@ -370,15 +403,22 @@ class AsyncQueueManager:
        await self.establish_connection()
        service_request_exchange = await self.channel.get_exchange(self.tenant_exchange_name)

        await service_request_exchange.publish(message=Message(body=message, delivery_mode=DeliveryMode.NOT_PERSISTENT), routing_key="tenant.delete")
        await service_request_exchange.publish(
            message=Message(
                body=message,
                delivery_mode=DeliveryMode.NOT_PERSISTENT,
                timestamp=datetime.datetime.now(),
                message_id=str(uuid.uuid4()),
            ),
            routing_key="tenant.delete",
        )

        logger.info(f"Published message to queue {self.tenant_deleted_queue_name}.")




async def main() -> None:
    import time

    settings = load_settings(local_pyinfra_root_path / "config/")
    callback = ""

@@ -386,10 +426,10 @@ async def main() -> None:

    await manager.main_loop()


    while True:
        time.sleep(100)
        print("keep idling")

if __name__ == '__main__':
    asyncio.run(main())

if __name__ == "__main__":
    asyncio.run(main())
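(Not part of the commit: a minimal sketch of driving the AsyncQueueManager publish helpers shown in this file, using the same settings path as the file's own main(); the tenant id and payload fields are illustrative only.)

import asyncio

from pyinfra.config.loader import load_settings, local_pyinfra_root_path
from pyinfra.queue.async_tenants import AsyncQueueManager


async def demo() -> None:
    settings = load_settings(local_pyinfra_root_path / "config/")
    manager = AsyncQueueManager(settings)

    # Announce a new tenant on the tenant exchange, then push a request onto that tenant's input queue.
    await manager.publish_message_to_tenant_created_queue({"tenantId": "tenant_a"})  # illustrative tenant id
    await manager.publish_message_to_input_queue("tenant_a", {"fileId": "file_1"})  # illustrative payload


if __name__ == "__main__":
    asyncio.run(demo())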
@@ -1,17 +1,18 @@
import asyncio
import gzip
from operator import itemgetter
from typing import Any, Callable, Dict, Set
import aiohttp
from aio_pika import connect_robust, ExchangeType, Message, IncomingMessage
from aio_pika.abc import AbstractIncomingMessage, AbstractChannel, AbstractConnection, AbstractExchange
import json
from kn_utils.logging import logger
from pyinfra.config.loader import load_settings, local_pyinfra_root_path
from pyinfra.storage.storages.s3 import get_s3_storage_from_settings
import signal
from dataclasses import dataclass, field
from typing import Any, Callable, Dict, Set

settings = load_settings(local_pyinfra_root_path / "config/")
import aiohttp
from aio_pika import ExchangeType, IncomingMessage, Message, connect_robust
from aio_pika.abc import (
    AbstractChannel,
    AbstractConnection,
    AbstractExchange,
    AbstractIncomingMessage,
)
from kn_utils.logging import logger


@dataclass
@@ -63,6 +64,10 @@ class RabbitMQHandler:
        self.channel = await self.connection.channel()
        await self.channel.set_qos(prefetch_count=1)

    async def is_ready(self) -> bool:
        await self.connect()
        return self.channel.is_open

    async def setup_exchanges(self):
        self.tenant_exchange = await self.channel.declare_exchange(
            self.config.tenant_exchange_name, ExchangeType.TOPIC, durable=True
@@ -75,6 +80,7 @@ class RabbitMQHandler:
        )

    async def setup_tenant_queue(self) -> None:
        # TODO: Add k8s pod_name to tenant queue name - add DLQ?
        queue = await self.channel.declare_queue(
            "tenant_queue",
            durable=True,
@@ -93,7 +99,7 @@ class RabbitMQHandler:
            tenant_id = message_body["tenantId"]
            routing_key = message.routing_key

            if routing_key == "tenant.create":
            if routing_key == "tenant.created":
                await self.create_tenant_queues(tenant_id)
            elif routing_key == "tenant.delete":
                await self.delete_tenant_queues(tenant_id)
@@ -119,16 +125,30 @@ class RabbitMQHandler:

    async def delete_tenant_queues(self, tenant_id: str) -> None:
        if tenant_id in self.tenant_queues:
            input_queue = self.tenant_queues[tenant_id]
            await input_queue.delete()
            # somehow queue.delete() does not work here
            await self.channel.queue_delete(f"{self.config.input_queue_prefix}_{tenant_id}")
            del self.tenant_queues[tenant_id]
            logger.info(f"Deleted queues for tenant {tenant_id}")

    async def process_input_message(self, message: IncomingMessage) -> None:
        async with message.process():
        async def process_message_body_and_await_result(unpacked_message_body):
            return self.message_processor(unpacked_message_body)

        async with message.process(ignore_processed=True):
            if message.redelivered:
                logger.warning(f"Declining message with {message.delivery_tag=} due to it being redelivered.")
                await message.nack(requeue=False)
                return

            if message.body.decode("utf-8") == "STOP":
                logger.info("Received stop signal, stopping consumption...")
                await message.ack()
                # TODO: shutdown is probably not the right call here - align w/ Dev what should happen on stop signal
                await self.shutdown()
                return

            try:
                tenant_id = message.routing_key
                message_body = json.loads(message.body.decode("utf-8"))

                filtered_message_headers = (
                    {k: v for k, v in message.headers.items() if k.lower().startswith("x-")} if message.headers else {}
@@ -136,18 +156,26 @@ class RabbitMQHandler:

                logger.debug(f"Processing message with {filtered_message_headers=}.")

                message_body.update(filtered_message_headers)

                result = await self.message_processor(message_body)
                result: dict = await (
                    process_message_body_and_await_result({**json.loads(message.body), **filtered_message_headers})
                    or {}
                )

                if result:
                    await self.publish_to_output_exchange(tenant_id, result, filtered_message_headers)
                    await message.ack()
                    logger.debug(f"Message with {message.delivery_tag=} acknowledged.")
                else:
                    raise ValueError(f"Could not process message with {message.body=}.")

            except json.JSONDecodeError:
                await message.nack(requeue=False)
                logger.error(f"Invalid JSON in input message: {message.body}")
            except FileNotFoundError as e:
                logger.warning(f"{e}, declining message.")
                logger.warning(f"{e}, declining message with {message.delivery_tag=}.")
                await message.nack(requeue=False)
            except Exception as e:
                await message.nack(requeue=False)
                logger.error(f"Error processing input message: {e}", exc_info=True)
                raise

@@ -162,6 +190,8 @@ class RabbitMQHandler:
        try:
            async with aiohttp.ClientSession() as session:
                async with session.get(self.tenant_service_url) as response:
                    # TODO: dont know if we should check for 200, could also be 2xx
                    # maybe handle bad requests with response.raise_for_status()
                    if response.status == 200:
                        data = await response.json()
                        return {tenant["tenantId"] for tenant in data}
@@ -178,105 +208,37 @@ class RabbitMQHandler:
            await self.create_tenant_queues(tenant_id)

    async def run(self) -> None:
        stop = asyncio.Event()

        def signal_handler(*_):
            logger.info("Signal received, shutting down...")
            stop.set()

        loop = asyncio.get_running_loop()
        for sig in (signal.SIGINT, signal.SIGTERM):
            loop.add_signal_handler(sig, signal_handler)

        try:
            await self.connect()
            await self.setup_exchanges()
            await self.initialize_tenant_queues()
            await self.setup_tenant_queue()

            logger.info("RabbitMQ handler is running. Press CTRL+C to exit.")
            await asyncio.Future()  # Run forever
            await stop.wait()  # Run until stop signal received
        except asyncio.CancelledError:
            logger.info("Shutting down RabbitMQ handler...")
            logger.info("Operation cancelled.")
        except Exception as e:
            logger.error(f"An error occurred: {e}", exc_info=True)
        finally:
            if self.connection:
                await self.connection.close()
            await self.shutdown()

    async def shutdown(self) -> None:
        logger.info("Shutting down RabbitMQ handler...")
        if self.channel:
            await self.channel.close()
        if self.connection:
            await self.connection.close()
        logger.info("RabbitMQ handler shut down successfully.")


async def dummy_message_processor(message: Dict[str, Any]) -> Dict[str, Any]:
    logger.info(f"Processing message: {message}")
    await asyncio.sleep(1)  # Simulate processing time

    storage = get_s3_storage_from_settings(settings)
    tenant_id, dossier_id, file_id = itemgetter("tenantId", "dossierId", "fileId")(message)
    suffix = message["responseFileExtension"]

    object_name = f"{tenant_id}/{dossier_id}/{file_id}.{message['targetFileExtension']}"
    original_content = json.loads(gzip.decompress(storage.get_object(object_name)))
    processed_content = {
        "processedPages": original_content["numberOfPages"],
        "processedSectionTexts": f"Processed: {original_content['sectionTexts']}",
    }

    processed_object_name = f"{tenant_id}/{dossier_id}/{file_id}.{suffix}"
    processed_data = gzip.compress(json.dumps(processed_content).encode("utf-8"))
    storage.put_object(processed_object_name, processed_data)

    processed_message = message.copy()
    processed_message["processed"] = True
    processed_message["processor_message"] = "This message was processed by the dummy processor"

    logger.info(f"Finished processing message. Result: {processed_message}")
    return processed_message


async def test_rabbitmq_handler() -> None:
    tenant_service_url = settings.storage.tenant_server.endpoint

    config = RabbitMQConfig(
        host=settings.rabbitmq.host,
        port=settings.rabbitmq.port,
        username=settings.rabbitmq.username,
        password=settings.rabbitmq.password,
        heartbeat=settings.rabbitmq.heartbeat,
        input_queue_prefix=settings.rabbitmq.service_request_queue_prefix,
        tenant_exchange_name=settings.rabbitmq.service_response_queue_prefix,
        service_request_exchange_name=settings.rabbitmq.service_request_exchange_name,
        service_response_exchange_name=settings.rabbitmq.service_response_exchange_name,
        service_dead_letter_queue_name=settings.rabbitmq.service_dlq_name,
        queue_expiration_time=settings.rabbitmq.queue_expiration_time,
    )

    handler = RabbitMQHandler(config, tenant_service_url, dummy_message_processor)

    await handler.connect()
    await handler.setup_exchanges()
    await handler.initialize_tenant_queues()
    await handler.setup_tenant_queue()

    tenant_id = "test_tenant"

    # Test tenant creation
    create_message = {"tenantId": tenant_id}
    await handler.tenant_exchange.publish(
        Message(body=json.dumps(create_message).encode()), routing_key="tenant.create"
    )
    logger.info(f"Sent create tenant message for {tenant_id}")
    await asyncio.sleep(2)  # Wait for queue creation

    # Test service request
    service_request = {
        "tenantId": tenant_id,
        "dossierId": "dossier",
        "fileId": "file",
        "targetFileExtension": "json.gz",
        "responseFileExtension": "result.json.gz",
    }
    await handler.input_exchange.publish(Message(body=json.dumps(service_request).encode()), routing_key=tenant_id)
    logger.info(f"Sent service request for {tenant_id}")
    await asyncio.sleep(5)  # Wait for message processing

    # Test tenant deletion
    delete_message = {"tenantId": tenant_id}
    await handler.tenant_exchange.publish(
        Message(body=json.dumps(delete_message).encode()), routing_key="tenant.delete"
    )
    logger.info(f"Sent delete tenant message for {tenant_id}")
    await asyncio.sleep(2)  # Wait for queue deletion

    await handler.connection.close()


if __name__ == "__main__":
    asyncio.run(test_rabbitmq_handler())
# TODO: purge_queues
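(Not part of the commit: the reworked run() above swaps the `await asyncio.Future()` run-forever idiom for an asyncio.Event that SIGINT/SIGTERM handlers set. A self-contained sketch of that pattern, independent of this repository and Unix-only because of loop.add_signal_handler, looks like this.)

import asyncio
import signal


async def run_until_signalled() -> None:
    stop = asyncio.Event()

    def signal_handler(*_):
        # Runs in the event loop thread; just flag the event so the caller can clean up.
        stop.set()

    loop = asyncio.get_running_loop()
    for sig in (signal.SIGINT, signal.SIGTERM):
        loop.add_signal_handler(sig, signal_handler)

    print("running, press CTRL+C to exit")
    await stop.wait()  # blocks until SIGINT or SIGTERM arrives
    print("shutting down")


if __name__ == "__main__":
    asyncio.run(run_until_signalled())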
scripts/send_async_request.py (new file, 153 lines)
@@ -0,0 +1,153 @@
import asyncio
import gzip
import json
from operator import itemgetter
from typing import Any, Dict

from aio_pika import Message
from aio_pika.abc import AbstractIncomingMessage
from kn_utils.logging import logger

from pyinfra.config.loader import load_settings, local_pyinfra_root_path
from pyinfra.queue.async_tenants_v2 import RabbitMQConfig, RabbitMQHandler
from pyinfra.storage.storages.s3 import S3Storage, get_s3_storage_from_settings

settings = load_settings(local_pyinfra_root_path / "config/")


async def dummy_message_processor(message: Dict[str, Any]) -> Dict[str, Any]:
    logger.info(f"Processing message: {message}")
    # await asyncio.sleep(1)  # Simulate processing time

    storage = get_s3_storage_from_settings(settings)
    tenant_id, dossier_id, file_id = itemgetter("tenantId", "dossierId", "fileId")(message)
    suffix = message["responseFileExtension"]

    object_name = f"{tenant_id}/{dossier_id}/{file_id}.{message['targetFileExtension']}"
    original_content = json.loads(gzip.decompress(storage.get_object(object_name)))
    processed_content = {
        "processedPages": original_content["numberOfPages"],
        "processedSectionTexts": f"Processed: {original_content['sectionTexts']}",
    }

    processed_object_name = f"{tenant_id}/{dossier_id}/{file_id}.{suffix}"
    processed_data = gzip.compress(json.dumps(processed_content).encode("utf-8"))
    storage.put_object(processed_object_name, processed_data)

    processed_message = message.copy()
    processed_message["processed"] = True
    processed_message["processor_message"] = "This message was processed by the dummy processor"

    logger.info(f"Finished processing message. Result: {processed_message}")
    return processed_message


async def on_response_message_callback(storage: S3Storage):
    async def on_message(message: AbstractIncomingMessage) -> None:
        async with message.process(ignore_processed=True):
            if not message.body:
                raise ValueError
            response = json.loads(message.body)
            logger.info(f"Received {response}")
            logger.info(f"Message headers: {message.properties.headers}")
            await message.ack()
            tenant_id, dossier_id, file_id = itemgetter("tenantId", "dossierId", "fileId")(response)
            suffix = response["responseFileExtension"]
            result = storage.get_object(f"{tenant_id}/{dossier_id}/{file_id}.{suffix}")
            result = json.loads(gzip.decompress(result))
            logger.info(f"Contents of result on storage: {result}")

    return on_message


def upload_json_and_make_message_body(tenant_id: str):
    dossier_id, file_id, suffix = "dossier", "file", "json.gz"
    content = {
        "numberOfPages": 7,
        "sectionTexts": "data",
    }

    object_name = f"{tenant_id}/{dossier_id}/{file_id}.{suffix}"
    data = gzip.compress(json.dumps(content).encode("utf-8"))

    storage = get_s3_storage_from_settings(settings)
    if not storage.has_bucket():
        storage.make_bucket()
    storage.put_object(object_name, data)

    message_body = {
        "tenantId": tenant_id,
        "dossierId": dossier_id,
        "fileId": file_id,
        "targetFileExtension": suffix,
        "responseFileExtension": f"result.{suffix}",
    }
    return message_body, storage
async def test_rabbitmq_handler() -> None:
    tenant_service_url = settings.storage.tenant_server.endpoint

    config = RabbitMQConfig(
        host=settings.rabbitmq.host,
        port=settings.rabbitmq.port,
        username=settings.rabbitmq.username,
        password=settings.rabbitmq.password,
        heartbeat=settings.rabbitmq.heartbeat,
        input_queue_prefix=settings.rabbitmq.service_request_queue_prefix,
        tenant_exchange_name=settings.rabbitmq.service_response_queue_prefix,
        service_request_exchange_name=settings.rabbitmq.service_request_exchange_name,
        service_response_exchange_name=settings.rabbitmq.service_response_exchange_name,
        service_dead_letter_queue_name=settings.rabbitmq.service_dlq_name,
        queue_expiration_time=settings.rabbitmq.queue_expiration_time,
    )

    handler = RabbitMQHandler(config, tenant_service_url, dummy_message_processor)

    await handler.connect()
    await handler.setup_exchanges()
    # await handler.initialize_tenant_queues()
    # await handler.setup_tenant_queue()

    # for queue in handler.tenant_queues.values():
    #     await queue.purge()

    tenant_id = "test_tenant"

    # Test tenant creation
    create_message = {"tenantId": tenant_id}
    await handler.tenant_exchange.publish(
        Message(body=json.dumps(create_message).encode()), routing_key="tenant.created"
    )
    logger.info(f"Sent create tenant message for {tenant_id}")
    await asyncio.sleep(2)  # Wait for queue creation

    # Prepare service request
    service_request, storage = upload_json_and_make_message_body(tenant_id)

    # Test service request
    await handler.input_exchange.publish(Message(body=json.dumps(service_request).encode()), routing_key=tenant_id)
    logger.info(f"Sent service request for {tenant_id}")
    await asyncio.sleep(5)  # Wait for message processing

    # Consume service request
    response_queue = await handler.channel.declare_queue(name=f"response_queue_{tenant_id}")
    await response_queue.bind(exchange=handler.output_exchange, routing_key=tenant_id)
    callback = await on_response_message_callback(storage)
    await response_queue.consume(callback=callback)

    await asyncio.sleep(5)  # Wait for message processing

    # Test tenant deletion
    delete_message = {"tenantId": tenant_id}
    await handler.tenant_exchange.publish(
        Message(body=json.dumps(delete_message).encode()), routing_key="tenant.delete"
    )
    logger.info(f"Sent delete tenant message for {tenant_id}")
    await asyncio.sleep(2)  # Wait for queue deletion

    await handler.connection.close()


if __name__ == "__main__":
    asyncio.run(test_rabbitmq_handler())
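(Not part of the commit: the script above stores gzip-compressed JSON blobs in S3 and reads them back after processing. Stripped of the storage layer, the encode/decode pair it relies on is just the following.)

import gzip
import json

payload = {"numberOfPages": 7, "sectionTexts": "data"}

# What the script writes to storage: UTF-8 JSON, gzip-compressed.
blob = gzip.compress(json.dumps(payload).encode("utf-8"))

# What the response callback reads back: decompress, then parse.
restored = json.loads(gzip.decompress(blob))
assert restored == payload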
@@ -2,14 +2,14 @@ import asyncio
import gzip
import json
import time
from aio_pika.abc import AbstractIncomingMessage
from operator import itemgetter

from aio_pika.abc import AbstractIncomingMessage
from kn_utils.logging import logger

from pyinfra.config.loader import load_settings, local_pyinfra_root_path
from pyinfra.queue.async_tenants import AsyncQueueManager
from pyinfra.storage.storages.s3 import get_s3_storage_from_settings
from pyinfra.storage.storages.s3 import S3Storage
from pyinfra.storage.storages.s3 import S3Storage, get_s3_storage_from_settings

settings = load_settings(local_pyinfra_root_path / "config/")

@@ -57,6 +57,7 @@ def on_message_callback(storage: S3Storage):
        result = storage.get_object(f"{tenant_id}/{dossier_id}/{file_id}.{suffix}")
        result = json.loads(gzip.decompress(result))
        logger.info(f"Contents of result on storage: {result}")

    return on_message


@@ -98,9 +99,7 @@ async def send_service_request(queue_manager: AsyncQueueManager, tenant_id: str)
if __name__ == "__main__":
    # tenant_ids = ["a", "b", "c", "d"]

    queue_manager = AsyncQueueManager(settings)

    # asyncio.run(send_tenant_event(queue_manager, "test", "create"))

    asyncio.run(send_service_request(queue_manager,"redaction"))
    # asyncio.run(send_tenant_event(AsyncQueueManager(settings), "test_1", "create"))

    asyncio.run(send_service_request(AsyncQueueManager(settings), "redaction"))
    # asyncio.run(consume_service_request(AsyncQueueManager(settings),"redaction"))
@@ -3,6 +3,7 @@ import json
import time
from operator import itemgetter
from threading import Thread

from kn_utils.logging import logger

from pyinfra.config.loader import load_settings, local_pyinfra_root_path
@@ -67,9 +68,7 @@ def send_service_request(tenant_id: str):

    storage = get_s3_storage_from_settings(settings)

    for method_frame, properties, body in queue_manager.channel.consume(
        queue=queue_name, inactivity_timeout=15
    ):
    for method_frame, properties, body in queue_manager.channel.consume(queue=queue_name, inactivity_timeout=15):
        if not body:
            break
        response = json.loads(body)
@@ -87,13 +86,15 @@ def send_service_request(tenant_id: str):


if __name__ == "__main__":
    tenant_ids = ["a", "b", "c", "d"]
    import uuid

    for tenant in tenant_ids:
    unique_ids = [str(uuid.uuid4()) for _ in range(100)]

    for tenant in unique_ids:
        send_tenant_event(tenant_id=tenant, event_type="create")

    for tenant in tenant_ids:
        send_service_request(tenant_id=tenant)
    # for tenant in tenant_ids:
    #     send_service_request(tenant_id=tenant)

    for tenant in tenant_ids:
        send_tenant_event(tenant_id=tenant, event_type="delete")
    # for tenant in tenant_ids:
    #     send_tenant_event(tenant_id=tenant, event_type="delete")