diff --git a/pyinfra/queue/queue_manager/pika_queue_manager.py b/pyinfra/queue/queue_manager/pika_queue_manager.py index c38d91f..27f82e3 100644 --- a/pyinfra/queue/queue_manager/pika_queue_manager.py +++ b/pyinfra/queue/queue_manager/pika_queue_manager.py @@ -108,9 +108,18 @@ class PikaQueueManager(QueueManager): n_attempts = get_n_previous_attempts(properties) + 1 try: - response = json.dumps(visitor(json.loads(body))).encode() - self.channel.basic_publish("", self._output_queue, response) + + response_messages = visitor(json.loads(body)) + + if isinstance(response_messages, dict): + response_messages = [response_messages] + + for response_message in response_messages: + response_message = json.dumps(response_message).encode() + self.channel.basic_publish("", self._output_queue, response_message) + self.channel.basic_ack(frame.delivery_tag) + except ProcessingFailure: logger.error(f"Message failed to process {n_attempts}/{max_attempts} times: {body}") diff --git a/pyinfra/visitor.py b/pyinfra/visitor.py index c14f97b..98857f3 100644 --- a/pyinfra/visitor.py +++ b/pyinfra/visitor.py @@ -122,12 +122,12 @@ class AggregationStorageStrategy(ResponseStrategy): self.dispatch_callback = dispatch_callback or IdentifierDispatchCallback() self.buffer = deque() - def put_object(self, data: bytes, metadata): - object_descriptor = get_response_object_descriptor(metadata) + def put_object(self, data: bytes, storage_upload_info): + object_descriptor = get_response_object_descriptor(storage_upload_info) # TODO: object_descriptor needs suffix + # Note: what did I mean with that? self.storage.put_object(**object_descriptor, data=data) - - # body["responseFile"] = response_object_descriptor["object_name"] + return {**storage_upload_info, "responseFile": object_descriptor["object_name"]} def merge_queue_items(self): merged_buffer_content = self.merger(self.buffer) @@ -136,7 +136,7 @@ class AggregationStorageStrategy(ResponseStrategy): def upload_queue_items(self, storage_upload_info): data = json.dumps(self.merge_queue_items()).encode() - self.put_object(data, storage_upload_info) + return self.put_object(data, storage_upload_info) def upload_or_aggregate(self, analysis_payload, request_metadata, last=False): """ @@ -146,20 +146,18 @@ class AggregationStorageStrategy(ResponseStrategy): storage_upload_info = {**request_metadata, "id": analysis_payload.get("id", "0")} if analysis_payload["data"]: - self.put_object(json.dumps(analysis_payload).encode(), storage_upload_info) + return self.put_object(json.dumps(analysis_payload).encode(), storage_upload_info) else: self.buffer.append(analysis_payload) if last or self.dispatch_callback(request_metadata): - self.upload_queue_items(storage_upload_info) + return self.upload_queue_items(storage_upload_info) def handle_response(self, payload, final=False): request_metadata = omit(payload, ["data"]) result_data = peekable(payload["data"]) for analysis_payload in result_data: - self.upload_or_aggregate(analysis_payload, request_metadata, last=not result_data.peek(False)) - - return request_metadata + yield self.upload_or_aggregate(analysis_payload, request_metadata, last=not result_data.peek(False)) class InvalidStorageItemFormat(ValueError): diff --git a/test/integration_tests/serve_test.py b/test/integration_tests/serve_test.py index 12be128..b10f1da 100644 --- a/test/integration_tests/serve_test.py +++ b/test/integration_tests/serve_test.py @@ -6,7 +6,7 @@ from operator import itemgetter import pytest from frozendict import frozendict -from funcy import lfilter, compose, lzip +from funcy import lfilter, compose, lzip, pluck, lpluck from pyinfra.default_objects import ( get_callback, @@ -80,7 +80,7 @@ def decode(storage_item): "queue_manager_name", [ "mock", - # "pika", + "pika", ], scope="session", ) @@ -130,10 +130,7 @@ def test_serving( for _, req in zip(adorned_data_metadata_packs, reqs): queue_manager.publish_response(req, visitor) - # TODO: pull files by responseFile field from visitor() result - - names_of_uploaded_files = lfilter(".out", storage.get_all_object_names(bucket_name)) + names_of_uploaded_files = lpluck("responseFile", queue_manager.output_queue.to_list()) uploaded_files = [storage.get_object(bucket_name, fn) for fn in names_of_uploaded_files] - outputs = sorted(chain(*map(decode, uploaded_files)), key=itemgetter(0)) assert outputs == targets diff --git a/test/queue/queue_manager_mock.py b/test/queue/queue_manager_mock.py index 063a68b..fe50a8f 100644 --- a/test/queue/queue_manager_mock.py +++ b/test/queue/queue_manager_mock.py @@ -17,7 +17,14 @@ class QueueManagerMock(QueueManager): self._input_queue.append(request) def publish_response(self, message, callback): - self._output_queue.append(callback(message)) + + response_messages = callback(message) + + if isinstance(response_messages, dict): + response_messages = [response_messages] + + for response_message in response_messages: + self._output_queue.append(response_message) def pull_request(self): return self._input_queue.popleft()