From fa44f36088d80482eb1205bff0e28dc96ad6d8ef Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jonathan=20K=C3=B6ssler?= Date: Wed, 21 Aug 2024 16:24:20 +0200 Subject: [PATCH] feat: add async webserver for probes --- pyinfra/examples.py | 22 ++++++++++--- pyinfra/queue/async_manager.py | 57 ++++++++++++++++++---------------- pyinfra/webserver/utils.py | 14 +++++++++ 3 files changed, 63 insertions(+), 30 deletions(-) diff --git a/pyinfra/examples.py b/pyinfra/examples.py index 1132259..69f83dc 100644 --- a/pyinfra/examples.py +++ b/pyinfra/examples.py @@ -15,10 +15,24 @@ from pyinfra.webserver.prometheus import ( ) from pyinfra.webserver.utils import ( add_health_check_endpoint, + create_webserver_task_from_settings, create_webserver_thread_from_settings, ) +async def run_async_tasks(manager, webserver): + """Run the webserver and the async queue manager concurrently.""" + + # Start the web server as an async task + webserver_task = asyncio.create_task(webserver) + + # Start the async queue manager + queue_manager_task = asyncio.create_task(manager.run()) + + # Wait for both tasks to complete (typically, they run indefinitely) + await asyncio.gather(webserver_task, queue_manager_task) + + def start_standard_queue_consumer( callback: Callback, settings: Dynaconf, @@ -72,12 +86,12 @@ def start_standard_queue_consumer( app = add_health_check_endpoint(app, manager.is_ready) - webserver_thread = create_webserver_thread_from_settings(app, settings) - webserver_thread.start() - if isinstance(manager, AsyncQueueManager): - asyncio.run(manager.run()) + webserver = create_webserver_task_from_settings(app, settings) + asyncio.run(run_async_tasks(manager, webserver)) elif isinstance(manager, QueueManager): + webserver = create_webserver_thread_from_settings(app, settings) + webserver.start() manager.start_consuming(callback) else: logger.warning(f"Behavior for type {type(manager)} is not defined") diff --git a/pyinfra/queue/async_manager.py b/pyinfra/queue/async_manager.py index 9031e24..92b8bb5 100644 --- a/pyinfra/queue/async_manager.py +++ b/pyinfra/queue/async_manager.py @@ -1,7 +1,6 @@ import asyncio import json import signal -import time from dataclasses import dataclass, field from typing import Any, Callable, Dict, Set @@ -101,7 +100,7 @@ class AsyncQueueManager: "x-dead-letter-exchange": "", "x-dead-letter-routing-key": self.config.service_dead_letter_queue_name, "x-expires": self.config.queue_expiration_time, - "x-max-priority": 2, + # "x-max-priority": 2, }, ) await self.tenant_exchange_queue.bind(self.tenant_exchange, routing_key="tenant.*") @@ -110,34 +109,40 @@ class AsyncQueueManager: ) async def process_tenant_message(self, message: AbstractIncomingMessage) -> None: - async with message.process(): - message_body = json.loads(message.body.decode()) - logger.debug(f"Tenant message received: {message_body}") - tenant_id = message_body["tenantId"] - routing_key = message.routing_key + try: + async with message.process(): + message_body = json.loads(message.body.decode()) + logger.debug(f"Tenant message received: {message_body}") + tenant_id = message_body["tenantId"] + routing_key = message.routing_key - if routing_key == "tenant.created": - await self.create_tenant_queues(tenant_id) - elif routing_key == "tenant.delete": - await self.delete_tenant_queues(tenant_id) + if routing_key == "tenant.created": + await self.create_tenant_queues(tenant_id) + elif routing_key == "tenant.delete": + await self.delete_tenant_queues(tenant_id) + except Exception as e: + logger.error(e, exc_info=True) async def create_tenant_queues(self, tenant_id: str) -> None: queue_name = f"{self.config.input_queue_prefix}_{tenant_id}" logger.info(f"Declaring queue: {queue_name}") - input_queue = await self.channel.declare_queue( - queue_name, - durable=True, - arguments={ - "x-dead-letter-exchange": "", - "x-dead-letter-routing-key": self.config.service_dead_letter_queue_name, - "x-expires": self.config.queue_expiration_time, - "x-max-priority": 2, - }, - ) - await input_queue.bind(self.input_exchange, routing_key=tenant_id) - self.consumer_tags[tenant_id] = await input_queue.consume(self.process_input_message) - self.tenant_queues[tenant_id] = input_queue - logger.info(f"Created queues for tenant {tenant_id}") + try: + input_queue = await self.channel.declare_queue( + queue_name, + durable=True, + # arguments={ + # "x-dead-letter-exchange": "", + # "x-dead-letter-routing-key": self.config.service_dead_letter_queue_name, + # "x-expires": self.config.queue_expiration_time, + # "x-max-priority": 2, + # }, + ) + await input_queue.bind(self.input_exchange, routing_key=tenant_id) + self.consumer_tags[tenant_id] = await input_queue.consume(self.process_input_message) + self.tenant_queues[tenant_id] = input_queue + logger.info(f"Created and started consuming queue for tenant {tenant_id}") + except Exception as e: + logger.error(e, exc_info=True) async def delete_tenant_queues(self, tenant_id: str) -> None: if tenant_id in self.tenant_queues: @@ -271,7 +276,7 @@ class AsyncQueueManager: await self.tenant_exchange_queue.cancel(self.consumer_tags["tenant_exchange_queue"]) while self.message_count != 0: logger.debug(f"Messages are still being processed: {self.message_count=} ") - time.sleep(2) + await asyncio.sleep(2) await self.channel.close() if self.connection: await self.connection.close() diff --git a/pyinfra/webserver/utils.py b/pyinfra/webserver/utils.py index 9a2a438..e8a5d4c 100644 --- a/pyinfra/webserver/utils.py +++ b/pyinfra/webserver/utils.py @@ -1,3 +1,4 @@ +import asyncio import inspect import logging import threading @@ -25,6 +26,19 @@ def create_webserver_thread(app: FastAPI, port: int, host: str) -> threading.Thr return thread +async def create_webserver_task_from_settings(app: FastAPI, settings: Dynaconf) -> asyncio.Task: + validate_settings(settings, validators=webserver_validators) + return await create_webserver_task(app=app, port=settings.webserver.port, host=settings.webserver.host) + + +async def create_webserver_task(app: FastAPI, port: int, host: str) -> asyncio.Task: + """Creates an asyncio task that runs a FastAPI webserver.""" + config = uvicorn.Config(app=app, host=host, port=port, log_level=logging.WARNING) + server = uvicorn.Server(config) + task = asyncio.create_task(server.serve()) + return task + + HealthFunction = Callable[[], bool]