fix: add semaphore to AsyncQueueManager to limit concurrent tasks

This commit is contained in:
Francisco Schulz 2024-09-23 15:19:40 +02:00
parent 8e21b2144c
commit 8ec13502a9
4 changed files with 40 additions and 23 deletions

View File

@ -1 +1 @@
3.10.12
3.10

View File

@ -1,14 +1,20 @@
import asyncio
from tenacity import retry, stop_after_attempt, wait_exponential, retry_if_exception_type
import sys
from dynaconf import Dynaconf
from fastapi import FastAPI
from kn_utils.logging import logger
from tenacity import (
retry,
retry_if_exception_type,
stop_after_attempt,
wait_exponential,
)
from pyinfra.config.loader import get_pyinfra_validators, validate_settings
from pyinfra.queue.async_manager import AsyncQueueManager, RabbitMQConfig
from pyinfra.queue.manager import QueueManager
from pyinfra.queue.callback import Callback
from pyinfra.queue.manager import QueueManager
from pyinfra.utils.opentelemetry import instrument_app, instrument_pika, setup_trace
from pyinfra.webserver.prometheus import (
add_prometheus_endpoint,
@ -34,6 +40,9 @@ async def run_async_queues(manager, app, port, host):
await run_async_webserver(app, port, host)
except asyncio.CancelledError:
logger.info("Main task is cancelled.")
except Exception as e:
logger.error(f"An error occurred while running async queues: {e}", exc_info=True)
sys.exit(1)
finally:
logger.info("Signal received, shutting down...")
await manager.shutdown()
@ -84,6 +93,7 @@ def start_standard_queue_consumer(
instrument_app(app)
if settings.dynamic_tenant_queues.enabled:
logger.info("Dynamic tenant queues enabled. Running async queues.")
config = RabbitMQConfig(
host=settings.rabbitmq.host,
port=settings.rabbitmq.port,
@ -100,9 +110,15 @@ def start_standard_queue_consumer(
pod_name=settings.kubernetes.pod_name,
)
manager = AsyncQueueManager(
config=config, tenant_service_url=settings.storage.tenant_server.endpoint, message_processor=callback
config=config,
tenant_service_url=settings.storage.tenant_server.endpoint,
message_processor=callback,
max_concurrent_tasks=(
settings.asyncio.max_concurrent_tasks if hasattr(settings.asyncio, "max_concurrent_tasks") else 10
),
)
else:
logger.info("Dynamic tenant queues disabled. Running sync queues.")
manager = QueueManager(settings)
app = add_health_check_endpoint(app, manager.is_ready)
@ -116,9 +132,7 @@ def start_standard_queue_consumer(
try:
manager.start_consuming(callback)
except Exception as e:
logger.error(f"An error occurred while consuming messages: {e}")
# Optionally, you can choose to exit here if you want to restart the process
# import sys
# sys.exit(1)
logger.error(f"An error occurred while consuming messages: {e}", exc_info=True)
sys.exit(1)
else:
logger.warning(f"Behavior for type {type(manager)} is not defined")

View File

@ -15,17 +15,16 @@ from aio_pika.abc import (
AbstractIncomingMessage,
AbstractQueue,
)
from aio_pika.exceptions import AMQPConnectionError, ChannelInvalidStateError
from aiormq.exceptions import AMQPConnectionError
from kn_utils.logging import logger
from tenacity import (
retry,
retry_if_exception_type,
stop_after_attempt,
wait_exponential_jitter,
wait_exponential,
retry_if_exception_type,
wait_exponential_jitter,
)
from aio_pika.exceptions import AMQPConnectionError, ChannelInvalidStateError
from aiormq.exceptions import AMQPConnectionError
@dataclass
@ -62,10 +61,12 @@ class AsyncQueueManager:
config: RabbitMQConfig,
tenant_service_url: str,
message_processor: Callable[[Dict[str, Any]], Dict[str, Any]],
max_concurrent_tasks: int = 10,
):
self.config = config
self.tenant_service_url = tenant_service_url
self.message_processor = message_processor
self.semaphore = asyncio.Semaphore(max_concurrent_tasks)
self.connection: AbstractConnection | None = None
self.channel: AbstractChannel | None = None
@ -178,11 +179,14 @@ class AsyncQueueManager:
async def process_input_message(self, message: IncomingMessage) -> None:
async def process_message_body_and_await_result(unpacked_message_body):
loop = asyncio.get_running_loop()
with concurrent.futures.ThreadPoolExecutor(max_workers=1) as thread_pool_executor:
logger.info("Processing payload in a separate thread.")
result = await loop.run_in_executor(thread_pool_executor, self.message_processor, unpacked_message_body)
return result
async with self.semaphore:
loop = asyncio.get_running_loop()
with concurrent.futures.ThreadPoolExecutor(max_workers=1) as thread_pool_executor:
logger.info("Processing payload in a separate thread.")
result = await loop.run_in_executor(
thread_pool_executor, self.message_processor, unpacked_message_body
)
return result
async with message.process(ignore_processed=True):
if message.redelivered:
@ -222,14 +226,13 @@ class AsyncQueueManager:
except json.JSONDecodeError:
await message.nack(requeue=False)
logger.error(f"Invalid JSON in input message: {message.body}")
logger.error(f"Invalid JSON in input message: {message.body}", exc_info=True)
except FileNotFoundError as e:
logger.warning(f"{e}, declining message with {message.delivery_tag=}.")
logger.warning(f"{e}, declining message with {message.delivery_tag=}.", exc_info=True)
await message.nack(requeue=False)
except Exception as e:
await message.nack(requeue=False)
logger.error(f"Error processing input message: {e}", exc_info=True)
raise
finally:
self.message_count -= 1
@ -269,7 +272,7 @@ class AsyncQueueManager:
try:
active_tenants = await self.fetch_active_tenants()
except (aiohttp.ClientResponseError, aiohttp.ClientConnectorError):
logger.warning("API calls to tenant server failed. No tenant queues initialized.")
logger.warning("API calls to tenant server failed. No tenant queues initialized.", exc_info=True)
active_tenants = set()
for tenant_id in active_tenants:
await self.create_tenant_queues(tenant_id)
@ -283,7 +286,7 @@ class AsyncQueueManager:
logger.info("RabbitMQ handler is running. Press CTRL+C to exit.")
except AMQPConnectionError as e:
logger.error(f"Failed to establish connection to RabbitMQ: {e}")
logger.error(f"Failed to establish connection to RabbitMQ: {e}", exc_info=True)
# TODO: implement a custom exception handling strategy here
except asyncio.CancelledError:
logger.warning("Operation cancelled.")

View File

@ -1,6 +1,6 @@
[tool.poetry]
name = "pyinfra"
version = "3.2.7"
version = "3.2.8"
description = ""
authors = ["Team Research <research@knecon.com>"]
license = "All rights reserved"