From f2018f9c86e4ab389a05e9417c67010a59f4f634 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jonathan=20K=C3=B6ssler?= Date: Fri, 23 Aug 2024 16:56:24 +0200 Subject: [PATCH] fix: process message in thread in event loop --- pyinfra/examples.py | 19 ++++++++++--------- pyinfra/queue/async_manager.py | 18 +++++++++--------- pyinfra/webserver/utils.py | 16 +++++----------- 3 files changed, 24 insertions(+), 29 deletions(-) diff --git a/pyinfra/examples.py b/pyinfra/examples.py index 558d88c..21ced10 100644 --- a/pyinfra/examples.py +++ b/pyinfra/examples.py @@ -15,17 +15,18 @@ from pyinfra.webserver.prometheus import ( ) from pyinfra.webserver.utils import ( add_health_check_endpoint, - create_webserver_task_from_settings, create_webserver_thread_from_settings, + run_async_webserver, ) -async def run_async_queues(manager, webserver): - """Run the webserver and the async queue manager concurrently.""" - webserver_task = asyncio.create_task(webserver) - queue_manager_task = asyncio.create_task(manager.run()) - - await asyncio.gather(webserver_task, queue_manager_task) +async def run_async_queues(manager, app, port, host): + """Run the async webserver and the async queue manager concurrently.""" + try: + await manager.run() + await run_async_webserver(app, port, host) + except asyncio.CancelledError: + logger.info("Main task is cancelled.") def start_standard_queue_consumer( @@ -82,8 +83,8 @@ def start_standard_queue_consumer( app = add_health_check_endpoint(app, manager.is_ready) if isinstance(manager, AsyncQueueManager): - webserver = create_webserver_task_from_settings(app, settings) - asyncio.run(run_async_queues(manager, webserver)) + asyncio.run(run_async_queues(manager, app, port=settings.webserver.port, host=settings.webserver.host)) + elif isinstance(manager, QueueManager): webserver = create_webserver_thread_from_settings(app, settings) webserver.start() diff --git a/pyinfra/queue/async_manager.py b/pyinfra/queue/async_manager.py index 92b8bb5..2f1fc26 100644 --- a/pyinfra/queue/async_manager.py +++ b/pyinfra/queue/async_manager.py @@ -1,6 +1,8 @@ import asyncio +import concurrent.futures import json import signal +import sys from dataclasses import dataclass, field from typing import Any, Callable, Dict, Set @@ -79,7 +81,7 @@ class AsyncQueueManager: async def is_ready(self) -> bool: await self.connect() - return self.channel.is_initialized + return self.connection is not None and not self.connection.is_closed async def setup_exchanges(self) -> None: self.tenant_exchange = await self.channel.declare_exchange( @@ -100,7 +102,6 @@ class AsyncQueueManager: "x-dead-letter-exchange": "", "x-dead-letter-routing-key": self.config.service_dead_letter_queue_name, "x-expires": self.config.queue_expiration_time, - # "x-max-priority": 2, }, ) await self.tenant_exchange_queue.bind(self.tenant_exchange, routing_key="tenant.*") @@ -130,12 +131,6 @@ class AsyncQueueManager: input_queue = await self.channel.declare_queue( queue_name, durable=True, - # arguments={ - # "x-dead-letter-exchange": "", - # "x-dead-letter-routing-key": self.config.service_dead_letter_queue_name, - # "x-expires": self.config.queue_expiration_time, - # "x-max-priority": 2, - # }, ) await input_queue.bind(self.input_exchange, routing_key=tenant_id) self.consumer_tags[tenant_id] = await input_queue.consume(self.process_input_message) @@ -154,7 +149,11 @@ class AsyncQueueManager: async def process_input_message(self, message: IncomingMessage) -> None: async def process_message_body_and_await_result(unpacked_message_body): - return self.message_processor(unpacked_message_body) + loop = asyncio.get_running_loop() + with concurrent.futures.ThreadPoolExecutor(max_workers=1) as thread_pool_executor: + logger.info("Processing payload in a separate thread.") + result = await loop.run_in_executor(thread_pool_executor, self.message_processor, unpacked_message_body) + return result async with message.process(ignore_processed=True): if message.redelivered: @@ -281,3 +280,4 @@ class AsyncQueueManager: if self.connection: await self.connection.close() logger.info("RabbitMQ handler shut down successfully.") + sys.exit(0) diff --git a/pyinfra/webserver/utils.py b/pyinfra/webserver/utils.py index e8a5d4c..b28d4cc 100644 --- a/pyinfra/webserver/utils.py +++ b/pyinfra/webserver/utils.py @@ -1,4 +1,3 @@ -import asyncio import inspect import logging import threading @@ -26,17 +25,12 @@ def create_webserver_thread(app: FastAPI, port: int, host: str) -> threading.Thr return thread -async def create_webserver_task_from_settings(app: FastAPI, settings: Dynaconf) -> asyncio.Task: - validate_settings(settings, validators=webserver_validators) - return await create_webserver_task(app=app, port=settings.webserver.port, host=settings.webserver.host) - - -async def create_webserver_task(app: FastAPI, port: int, host: str) -> asyncio.Task: - """Creates an asyncio task that runs a FastAPI webserver.""" - config = uvicorn.Config(app=app, host=host, port=port, log_level=logging.WARNING) +async def run_async_webserver(app: FastAPI, port: int, host: str): + """Run the FastAPI web server async.""" + config = uvicorn.Config(app, host=host, port=port, log_level=logging.WARNING) server = uvicorn.Server(config) - task = asyncio.create_task(server.serve()) - return task + + await server.serve() HealthFunction = Callable[[], bool]