diff --git a/scripts/test_async_class.py b/scripts/test_async_class.py index 95af16b..8a66d8f 100644 --- a/scripts/test_async_class.py +++ b/scripts/test_async_class.py @@ -1,4 +1,6 @@ import asyncio +import gzip +from operator import itemgetter from typing import Any, Callable, Dict, Set import aiohttp from aio_pika import connect_robust, ExchangeType, Message, IncomingMessage @@ -6,7 +8,7 @@ from aio_pika.abc import AbstractIncomingMessage import json from logging import getLogger from pyinfra.config.loader import load_settings, local_pyinfra_root_path -import requests +from pyinfra.storage.storages.s3 import get_s3_storage_from_settings logger = getLogger(__name__) logger.setLevel("DEBUG") @@ -65,7 +67,7 @@ class RabbitMQHandler: message_body = json.loads(message.body.decode()) print(message_body) logger.debug(f"input message: {message_body}") - tenant_id = message_body["queue_name"] + tenant_id = message_body["tenantId"] routing_key = message.routing_key if routing_key == "tenant.create": @@ -125,11 +127,7 @@ class RabbitMQHandler: if result: # Publish the result to the output exchange - await self.output_exchange.publish( - Message(body=json.dumps(result).encode(), headers=filtered_message_headers), - routing_key=tenant_id, - ) - logger.info(f"Published result to queue {tenant_id}.") + self.publish_to_output_exchange(tenant_id, result, filtered_message_headers) except json.JSONDecodeError: logger.error(f"Invalid JSON in input message: {message.body}") @@ -142,6 +140,13 @@ class RabbitMQHandler: # Message will be nacked and sent to dead-letter queue raise + async def publish_to_output_exchange(self, tenant_id: str, result: Dict[str, Any], headers: Dict[str, Any]): + await self.output_exchange.publish( + Message(body=json.dumps(result).encode(), headers=headers), + routing_key=tenant_id, + ) + logger.info(f"Published result to queue {tenant_id}.") + async def fetch_active_tenants(self) -> Set[str]: try: async with aiohttp.ClientSession() as session: @@ -181,9 +186,114 @@ async def main(): "client_properties": {"heartbeat": settings.rabbitmq.heartbeat}, } tenant_service_url = "http://localhost:8080/internal/tenants" - handler = RabbitMQHandler(connection_params, tenant_service_url) + handler = RabbitMQHandler(connection_params, tenant_service_url, dummy_message_processor) await handler.run() +######################################################################## +def upload_json_and_make_message_body(tenant_id: str) -> Dict[str, Any]: + dossier_id, file_id, suffix = "dossier", "file", "json.gz" + content = { + "numberOfPages": 7, + "sectionTexts": "data", + } + + object_name = f"{tenant_id}/{dossier_id}/{file_id}.{suffix}" + data = gzip.compress(json.dumps(content).encode("utf-8")) + + storage = get_s3_storage_from_settings(settings) + if not storage.has_bucket(): + storage.make_bucket() + storage.put_object(object_name, data) + + message_body = { + "tenantId": tenant_id, + "dossierId": dossier_id, + "fileId": file_id, + "targetFileExtension": suffix, + "responseFileExtension": f"result.{suffix}", + } + return message_body + + +def tenant_event_message(tenant_id: str) -> Dict[str, str]: + return {"tenantId": tenant_id} + + +async def dummy_message_processor(message: Dict[str, Any]) -> Dict[str, Any]: + logger.info(f"Processing message: {message}") + await asyncio.sleep(1) # Simulate processing time + + storage = get_s3_storage_from_settings(settings) + tenant_id, dossier_id, file_id = itemgetter("tenantId", "dossierId", "fileId")(message) + suffix = message["responseFileExtension"] + + # Simulate processing by modifying the original content + object_name = f"{tenant_id}/{dossier_id}/{file_id}.{message['targetFileExtension']}" + original_content = json.loads(gzip.decompress(storage.get_object(object_name))) + processed_content = { + "processedPages": original_content["numberOfPages"], + "processedSectionTexts": f"Processed: {original_content['sectionTexts']}", + } + + # Save processed content + processed_object_name = f"{tenant_id}/{dossier_id}/{file_id}.{suffix}" + processed_data = gzip.compress(json.dumps(processed_content).encode("utf-8")) + storage.put_object(processed_object_name, processed_data) + + processed_message = message.copy() + processed_message["processed"] = True + processed_message["processor_message"] = "This message was processed by the dummy processor" + + logger.info(f"Finished processing message. Result: {processed_message}") + return processed_message + + +async def test_rabbitmq_handler(): + connection_params = { + "host": settings.rabbitmq.host, + "port": settings.rabbitmq.port, + "login": settings.rabbitmq.username, + "password": settings.rabbitmq.password, + "client_properties": {"heartbeat": settings.rabbitmq.heartbeat}, + } + tenant_service_url = "http://localhost:8080/internal/tenants" + handler = RabbitMQHandler(connection_params, tenant_service_url, dummy_message_processor) + + await handler.connect() + await handler.setup_tenant_queue() + + # Test tenant creation + tenant_id = "test_tenant" + create_message = tenant_event_message(tenant_id) + await handler.tenant_exchange.publish( + Message(body=json.dumps(create_message).encode()), routing_key="tenant.create" + ) + logger.info(f"Sent create tenant message for {tenant_id}") + + # Wait for tenant queue creation + await asyncio.sleep(2) + + # Test service request + service_request = upload_json_and_make_message_body(tenant_id) + await handler.input_exchange.publish(Message(body=json.dumps(service_request).encode()), routing_key=tenant_id) + logger.info(f"Sent service request for {tenant_id}") + + # Wait for message processing + await asyncio.sleep(5) + + # Test tenant deletion + delete_message = tenant_event_message(tenant_id) + await handler.tenant_exchange.publish( + Message(body=json.dumps(delete_message).encode()), routing_key="tenant.delete" + ) + logger.info(f"Sent delete tenant message for {tenant_id}") + + # Wait for tenant queue deletion + await asyncio.sleep(2) + + await handler.connection.close() + + if __name__ == "__main__": - asyncio.run(main()) + asyncio.run(test_rabbitmq_handler())