125 lines
5.0 KiB
Python
125 lines
5.0 KiB
Python
import asyncio
|
|
import sys
|
|
|
|
from aiormq.exceptions import AMQPConnectionError
|
|
from dynaconf import Dynaconf
|
|
from fastapi import FastAPI
|
|
from kn_utils.logging import logger
|
|
|
|
from pyinfra.config.loader import get_pyinfra_validators, validate_settings
|
|
from pyinfra.queue.async_manager import AsyncQueueManager, RabbitMQConfig
|
|
from pyinfra.queue.callback import Callback
|
|
from pyinfra.queue.manager import QueueManager
|
|
from pyinfra.utils.opentelemetry import instrument_app, instrument_pika, setup_trace
|
|
from pyinfra.webserver.prometheus import (
|
|
add_prometheus_endpoint,
|
|
make_prometheus_processing_time_decorator_from_settings,
|
|
)
|
|
from pyinfra.webserver.utils import (
|
|
add_health_check_endpoint,
|
|
create_webserver_thread_from_settings,
|
|
run_async_webserver,
|
|
)
|
|
|
|
|
|
async def run_async_queues(manager: AsyncQueueManager, app, port, host):
    """Run the async webserver and the async queue manager concurrently.

    Both coroutines run as named tasks ("queues" and "webserver") and are awaited together. However the
    run ends, the ``finally`` clause does the same cleanup:

    1. cancel whichever task is still pending;
    2. call ``manager.shutdown()``;
    3. await the created tasks so none is left dangling.

    The run can end through normal completion, cancellation of this coroutine, an ``AMQPConnectionError``,
    or any other exception.

    Args:
        manager: Async queue manager whose ``run()`` coroutine consumes messages.
        app: Application handed to ``run_async_webserver``.
        port: Port handed to ``run_async_webserver``.
        host: Host handed to ``run_async_webserver``.

    Raises:
        SystemExit: With exit code 1 when an unexpected exception occurs (raised after the cleanup above).
    """
    queue_task = None
    webserver_task = None
    try:
        queue_task = asyncio.create_task(manager.run(), name="queues")
        webserver_task = asyncio.create_task(run_async_webserver(app, port, host), name="webserver")
        await asyncio.gather(queue_task, webserver_task)
    except asyncio.CancelledError:
        # Cancellation of this coroutine is treated as a regular shutdown request, not re-raised.
        logger.info("Main task was cancelled, initiating shutdown.")
    except AMQPConnectionError as e:
        logger.warning(f"AMQPConnectionError: {e} - shutting down.")
    except Exception as e:
        logger.error(f"An error occurred while running async queues: {e}", exc_info=True)
        sys.exit(1)
    finally:
        logger.info("Signal received, shutting down...")
        # Only tasks that were actually created can be cancelled/awaited; handing None to
        # asyncio.gather would raise TypeError and hide the error that got us here.
        tasks = [task for task in (queue_task, webserver_task) if task is not None]
        for task in tasks:
            if not task.done():
                task.cancel()
        try:
            await manager.shutdown()
        finally:
            # Reap the tasks even if shutdown() fails, so their results/exceptions are retrieved.
            await asyncio.gather(*tasks, return_exceptions=True)
|
|
|
|
|
|
def start_standard_queue_consumer(
    callback: Callback,
    settings: Dynaconf,
    app: FastAPI = None,
):
    """Default serving logic for research services.

    Supplies /health, /ready and /prometheus endpoints (if enabled). The callback is monitored for processing
    time per message. Also traces the queue messages via OpenTelemetry (if enabled).

    Workload is received via queue messages and processed by the callback function (see pyinfra.queue.callback
    for callbacks).

    The run mode depends on ``settings.dynamic_tenant_queues.enabled``:

    - enabled: an ``AsyncQueueManager`` and the webserver run together on one event loop
      (``run_async_queues``);
    - disabled: a ``QueueManager`` consumes messages synchronously while the webserver runs in a thread
      created from the settings.

    Args:
        callback: Message processing function handed to the queue manager.
        settings: Service settings; validated against the pyinfra validators before anything else happens.
        app: Optional FastAPI application to extend with endpoints; a new ``FastAPI()`` is created when omitted.

    Raises:
        SystemExit: With exit code 1 when the queue consumer fails with an unexpected error.
    """
    validate_settings(settings, get_pyinfra_validators())

    logger.info("Starting webserver and queue consumer...")

    app = app or FastAPI()

    if settings.metrics.prometheus.enabled:
        logger.info("Prometheus metrics enabled.")
        app = add_prometheus_endpoint(app)
        # Wrap the callback so its processing time per message is recorded.
        callback = make_prometheus_processing_time_decorator_from_settings(settings)(callback)

    if settings.tracing.enabled:
        setup_trace(settings)

        instrument_pika(dynamic_queues=settings.dynamic_tenant_queues.enabled)
        instrument_app(app)

    if settings.dynamic_tenant_queues.enabled:
        logger.info("Dynamic tenant queues enabled. Running async queues.")
        manager = _build_async_queue_manager(settings, callback)
    else:
        logger.info("Dynamic tenant queues disabled. Running sync queues.")
        manager = QueueManager(settings)

    app = add_health_check_endpoint(app, manager.is_ready)

    if isinstance(manager, AsyncQueueManager):
        asyncio.run(run_async_queues(manager, app, port=settings.webserver.port, host=settings.webserver.host))

    elif isinstance(manager, QueueManager):
        webserver = create_webserver_thread_from_settings(app, settings)
        webserver.start()
        try:
            manager.start_consuming(callback)
        except Exception as e:
            logger.error(f"An error occurred while consuming messages: {e}", exc_info=True)
            sys.exit(1)
    else:
        logger.warning(f"Behavior for type {type(manager)} is not defined")


def _build_async_queue_manager(settings: Dynaconf, callback: Callback) -> AsyncQueueManager:
    """Create the AsyncQueueManager used for dynamic tenant queues from the service settings.

    Args:
        settings: Service settings providing the ``rabbitmq``, ``kubernetes`` and ``storage`` sections and,
            optionally, ``asyncio.max_concurrent_tasks``.
        callback: Message processing function passed as ``message_processor``.

    Returns:
        A configured, not yet running, AsyncQueueManager.
    """
    config = RabbitMQConfig(
        host=settings.rabbitmq.host,
        port=settings.rabbitmq.port,
        username=settings.rabbitmq.username,
        password=settings.rabbitmq.password,
        heartbeat=settings.rabbitmq.heartbeat,
        input_queue_prefix=settings.rabbitmq.service_request_queue_prefix,
        tenant_event_queue_suffix=settings.rabbitmq.tenant_event_queue_suffix,
        tenant_exchange_name=settings.rabbitmq.tenant_exchange_name,
        service_request_exchange_name=settings.rabbitmq.service_request_exchange_name,
        service_response_exchange_name=settings.rabbitmq.service_response_exchange_name,
        service_dead_letter_queue_name=settings.rabbitmq.service_dlq_name,
        queue_expiration_time=settings.rabbitmq.queue_expiration_time,
        pod_name=settings.kubernetes.pod_name,
    )
    # ``asyncio.max_concurrent_tasks`` is optional. Chained getattr defaults make both a missing key and a
    # missing ``asyncio`` section fall back to 10, instead of raising AttributeError on ``settings.asyncio``.
    asyncio_settings = getattr(settings, "asyncio", None)
    max_concurrent_tasks = getattr(asyncio_settings, "max_concurrent_tasks", 10)
    return AsyncQueueManager(
        config=config,
        tenant_service_url=settings.storage.tenant_server.endpoint,
        message_processor=callback,
        max_concurrent_tasks=max_concurrent_tasks,
    )
|