feat: add async webserver for probes

This commit is contained in:
Jonathan Kössler 2024-08-21 16:24:20 +02:00
parent 2970823cc1
commit fa44f36088
3 changed files with 63 additions and 30 deletions

View File

@ -15,10 +15,24 @@ from pyinfra.webserver.prometheus import (
) )
from pyinfra.webserver.utils import ( from pyinfra.webserver.utils import (
add_health_check_endpoint, add_health_check_endpoint,
create_webserver_task_from_settings,
create_webserver_thread_from_settings, create_webserver_thread_from_settings,
) )
async def run_async_tasks(manager, webserver):
"""Run the webserver and the async queue manager concurrently."""
# Start the web server as an async task
webserver_task = asyncio.create_task(webserver)
# Start the async queue manager
queue_manager_task = asyncio.create_task(manager.run())
# Wait for both tasks to complete (typically, they run indefinitely)
await asyncio.gather(webserver_task, queue_manager_task)
def start_standard_queue_consumer( def start_standard_queue_consumer(
callback: Callback, callback: Callback,
settings: Dynaconf, settings: Dynaconf,
@ -72,12 +86,12 @@ def start_standard_queue_consumer(
app = add_health_check_endpoint(app, manager.is_ready) app = add_health_check_endpoint(app, manager.is_ready)
webserver_thread = create_webserver_thread_from_settings(app, settings)
webserver_thread.start()
if isinstance(manager, AsyncQueueManager): if isinstance(manager, AsyncQueueManager):
asyncio.run(manager.run()) webserver = create_webserver_task_from_settings(app, settings)
asyncio.run(run_async_tasks(manager, webserver))
elif isinstance(manager, QueueManager): elif isinstance(manager, QueueManager):
webserver = create_webserver_thread_from_settings(app, settings)
webserver.start()
manager.start_consuming(callback) manager.start_consuming(callback)
else: else:
logger.warning(f"Behavior for type {type(manager)} is not defined") logger.warning(f"Behavior for type {type(manager)} is not defined")

View File

@ -1,7 +1,6 @@
import asyncio import asyncio
import json import json
import signal import signal
import time
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import Any, Callable, Dict, Set from typing import Any, Callable, Dict, Set
@ -101,7 +100,7 @@ class AsyncQueueManager:
"x-dead-letter-exchange": "", "x-dead-letter-exchange": "",
"x-dead-letter-routing-key": self.config.service_dead_letter_queue_name, "x-dead-letter-routing-key": self.config.service_dead_letter_queue_name,
"x-expires": self.config.queue_expiration_time, "x-expires": self.config.queue_expiration_time,
"x-max-priority": 2, # "x-max-priority": 2,
}, },
) )
await self.tenant_exchange_queue.bind(self.tenant_exchange, routing_key="tenant.*") await self.tenant_exchange_queue.bind(self.tenant_exchange, routing_key="tenant.*")
@ -110,34 +109,40 @@ class AsyncQueueManager:
) )
async def process_tenant_message(self, message: AbstractIncomingMessage) -> None: async def process_tenant_message(self, message: AbstractIncomingMessage) -> None:
async with message.process(): try:
message_body = json.loads(message.body.decode()) async with message.process():
logger.debug(f"Tenant message received: {message_body}") message_body = json.loads(message.body.decode())
tenant_id = message_body["tenantId"] logger.debug(f"Tenant message received: {message_body}")
routing_key = message.routing_key tenant_id = message_body["tenantId"]
routing_key = message.routing_key
if routing_key == "tenant.created": if routing_key == "tenant.created":
await self.create_tenant_queues(tenant_id) await self.create_tenant_queues(tenant_id)
elif routing_key == "tenant.delete": elif routing_key == "tenant.delete":
await self.delete_tenant_queues(tenant_id) await self.delete_tenant_queues(tenant_id)
except Exception as e:
logger.error(e, exc_info=True)
async def create_tenant_queues(self, tenant_id: str) -> None: async def create_tenant_queues(self, tenant_id: str) -> None:
queue_name = f"{self.config.input_queue_prefix}_{tenant_id}" queue_name = f"{self.config.input_queue_prefix}_{tenant_id}"
logger.info(f"Declaring queue: {queue_name}") logger.info(f"Declaring queue: {queue_name}")
input_queue = await self.channel.declare_queue( try:
queue_name, input_queue = await self.channel.declare_queue(
durable=True, queue_name,
arguments={ durable=True,
"x-dead-letter-exchange": "", # arguments={
"x-dead-letter-routing-key": self.config.service_dead_letter_queue_name, # "x-dead-letter-exchange": "",
"x-expires": self.config.queue_expiration_time, # "x-dead-letter-routing-key": self.config.service_dead_letter_queue_name,
"x-max-priority": 2, # "x-expires": self.config.queue_expiration_time,
}, # "x-max-priority": 2,
) # },
await input_queue.bind(self.input_exchange, routing_key=tenant_id) )
self.consumer_tags[tenant_id] = await input_queue.consume(self.process_input_message) await input_queue.bind(self.input_exchange, routing_key=tenant_id)
self.tenant_queues[tenant_id] = input_queue self.consumer_tags[tenant_id] = await input_queue.consume(self.process_input_message)
logger.info(f"Created queues for tenant {tenant_id}") self.tenant_queues[tenant_id] = input_queue
logger.info(f"Created and started consuming queue for tenant {tenant_id}")
except Exception as e:
logger.error(e, exc_info=True)
async def delete_tenant_queues(self, tenant_id: str) -> None: async def delete_tenant_queues(self, tenant_id: str) -> None:
if tenant_id in self.tenant_queues: if tenant_id in self.tenant_queues:
@ -271,7 +276,7 @@ class AsyncQueueManager:
await self.tenant_exchange_queue.cancel(self.consumer_tags["tenant_exchange_queue"]) await self.tenant_exchange_queue.cancel(self.consumer_tags["tenant_exchange_queue"])
while self.message_count != 0: while self.message_count != 0:
logger.debug(f"Messages are still being processed: {self.message_count=} ") logger.debug(f"Messages are still being processed: {self.message_count=} ")
time.sleep(2) await asyncio.sleep(2)
await self.channel.close() await self.channel.close()
if self.connection: if self.connection:
await self.connection.close() await self.connection.close()

View File

@ -1,3 +1,4 @@
import asyncio
import inspect import inspect
import logging import logging
import threading import threading
@ -25,6 +26,19 @@ def create_webserver_thread(app: FastAPI, port: int, host: str) -> threading.Thr
return thread return thread
async def create_webserver_task_from_settings(app: FastAPI, settings: Dynaconf) -> asyncio.Task:
validate_settings(settings, validators=webserver_validators)
return await create_webserver_task(app=app, port=settings.webserver.port, host=settings.webserver.host)
async def create_webserver_task(app: FastAPI, port: int, host: str) -> asyncio.Task:
"""Creates an asyncio task that runs a FastAPI webserver."""
config = uvicorn.Config(app=app, host=host, port=port, log_level=logging.WARNING)
server = uvicorn.Server(config)
task = asyncio.create_task(server.serve())
return task
HealthFunction = Callable[[], bool] HealthFunction = Callable[[], bool]