diff --git a/pyinfra/default_objects.py b/pyinfra/default_objects.py
index 3beee27..6d4d3e0 100644
--- a/pyinfra/default_objects.py
+++ b/pyinfra/default_objects.py
@@ -60,7 +60,9 @@ def make_callback(analysis_endpoint):
         # queue of the pipeline. Probably the pipeline return value needs to contain the queue message frame (or
         # so), in order for the queue manager to tell which message to ack.
         analysis_response_stream = pipeline([data], [metadata])
-        return analysis_response_stream
+        # TODO: casting to list is a temporary solution while the client pipeline operates on singletons
+        # ([data], [metadata]).
+        return list(analysis_response_stream)
     except Exception as err:
         logging.warning(f"Exception caught when calling analysis endpoint {endpoint}.")
         raise AnalysisFailure() from err
diff --git a/pyinfra/visitor.py b/pyinfra/visitor.py
index 5a0682a..c14f97b 100644
--- a/pyinfra/visitor.py
+++ b/pyinfra/visitor.py
@@ -6,9 +6,10 @@ import logging
 import time
 from collections import deque
 from operator import itemgetter
-from typing import Callable
+from typing import Callable, Iterable
 
 from funcy import omit
+from more_itertools import peekable
 
 from pyinfra.config import CONFIG, parse_disjunction_string
 from pyinfra.exceptions import DataLoadingFailure
@@ -32,10 +33,19 @@ def get_object_name(body):
 
 
 def get_response_object_name(body):
+
+    if "pages" not in body:
+        body["pages"] = []
+
+    if "id" not in body:
+        body["id"] = 0
+
     dossier_id, file_id, pages, idnt, response_file_extension = itemgetter(
         "dossierId", "fileId", "pages", "id", "responseFileExtension"
     )(body)
+
     object_name = f"{dossier_id}/{file_id}_{unique_hash(pages)}-id:{idnt}.{response_file_extension}"
+
     return object_name
@@ -64,8 +74,10 @@ class StorageStrategy(ResponseStrategy):
         self.storage = storage
 
     def handle_response(self, body):
-        self.storage.put_object(**get_response_object_descriptor(body), data=gzip.compress(json.dumps(body).encode()))
+        response_object_descriptor = get_response_object_descriptor(body)
+        self.storage.put_object(**response_object_descriptor, data=gzip.compress(json.dumps(body).encode()))
         body.pop("data")
+        body["responseFile"] = response_object_descriptor["object_name"]
         return body
@@ -115,16 +127,18 @@ class AggregationStorageStrategy(ResponseStrategy):
         # TODO: object_descriptor needs suffix
         self.storage.put_object(**object_descriptor, data=data)
 
+        # body["responseFile"] = response_object_descriptor["object_name"]
+
     def merge_queue_items(self):
         merged_buffer_content = self.merger(self.buffer)
         self.buffer.clear()
         return merged_buffer_content
 
-    def upload_queue_items(self, metadata):
+    def upload_queue_items(self, storage_upload_info):
         data = json.dumps(self.merge_queue_items()).encode()
-        self.put_object(data, metadata)
+        self.put_object(data, storage_upload_info)
 
-    def upload_or_aggregate(self, analysis_payload, request_metadata):
+    def upload_or_aggregate(self, analysis_payload, request_metadata, last=False):
         """
         analysis_payload : {data: ..., metadata: ...}
         """
@@ -136,14 +150,15 @@ class AggregationStorageStrategy(ResponseStrategy):
         else:
             self.buffer.append(analysis_payload)
 
-        if self.dispatch_callback(request_metadata):
+        if last or self.dispatch_callback(request_metadata):
             self.upload_queue_items(storage_upload_info)
 
     def handle_response(self, payload, final=False):
-        request_metadata = omit(payload, ["result_data"])
-        result_data = payload["result_data"]
-        for item in result_data:
-            self.upload_or_aggregate(item, request_metadata)
+        request_metadata = omit(payload, ["data"])
+        result_data = peekable(payload["data"])
+        for analysis_payload in result_data:
+            self.upload_or_aggregate(analysis_payload, request_metadata, last=not result_data.peek(False))
+
         return request_metadata
@@ -188,14 +203,15 @@ class QueueVisitor:
 
         try:
             data = json.loads(data.decode())
-            if not isinstance(data, dict):  # case 1
-                return wrap(string_to_bytes(data))
-            else:  # case 2
-                validate(data)
-                data["data"] = string_to_bytes(data["data"])
-                return data
         except json.JSONDecodeError:  # case 1 fallback
-            wrap(data.decode())
+            return wrap(data.decode())
+
+        if not isinstance(data, dict):  # case 1
+            return wrap(string_to_bytes(data))
+        else:  # case 2
+            validate(data)
+            data["data"] = string_to_bytes(data["data"])
+            return data
 
     def load_data(self, body):
         object_descriptor = get_object_descriptor(body)
@@ -206,16 +222,16 @@ class QueueVisitor:
         data = self.standardize(data)
         return data
 
-    def process_data(self, data_metadata_pack):
+    def process_storage_item(self, data_metadata_pack):
         return self.callback(data_metadata_pack)
 
-    def load_and_process(self, body):
-        data_from_storage = self.load_data(body)
-        result = list(self.process_data(data_from_storage))
-        # result = lmap(json.dumps, result)
-        result_body = {"result_data": result, **body}
+    def load_item_from_storage_and_process_with_callback(self, queue_item_body):
+        """Bundles the result from processing a storage item with the body of the corresponding queue item."""
+        storage_item = self.load_data(queue_item_body)
+        result = self.process_storage_item(storage_item)
+        result_body = {"data": result, **queue_item_body}
         return result_body
 
-    def __call__(self, body):
-        result_body = self.load_and_process(body)
-        return self.response_strategy(result_body)
+    def __call__(self, queue_item_body):
+        analysis_result_body = self.load_item_from_storage_and_process_with_callback(queue_item_body)
+        return self.response_strategy(analysis_result_body)
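Aside on the last-item detection introduced in AggregationStorageStrategy.handle_response above -- a minimal sketch of the peekable pattern, assuming only that more_itertools is installed (the names below are illustrative, not from this codebase):

    from more_itertools import peekable

    items = peekable(["first", "second", "last"])
    for item in items:
        # peek(False) yields the sentinel False once the iterable is exhausted,
        # so `last` is True exactly once, on the final element.
        last = not items.peek(False)
        print(item, last)
    # first False
    # second False
    # last True

One caveat worth knowing: a falsy upcoming element (an empty dict, say) would also make peek(False) falsy; `not bool(items)` is the stricter exhaustion test if payload items can ever be empty.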
diff --git a/test/conftest.py b/test/conftest.py
index 660f016..c7b60cd 100644
--- a/test/conftest.py
+++ b/test/conftest.py
@@ -154,7 +154,7 @@ def queue_manager(queue_manager_name, docker_compose):
 @pytest.fixture(scope="session")
 def callback():
     def inner(request):
-        return request["data"].decode() * 2
+        return [request["data"].decode() * 2]
 
     return inner
diff --git a/test/integration_tests/serve_test.py b/test/integration_tests/serve_test.py
index 394887c..12be128 100644
--- a/test/integration_tests/serve_test.py
+++ b/test/integration_tests/serve_test.py
@@ -58,7 +58,7 @@ def decode(storage_item):
         True,
     ],
 )
-@pytest.mark.parametrize("n_items", [2])
+@pytest.mark.parametrize("n_items", [1, 2])
 @pytest.mark.parametrize("n_pages", [1])
 @pytest.mark.parametrize("buffer_size", [2])
 @pytest.mark.parametrize(
@@ -130,6 +130,8 @@ def test_serving(
     for _, req in zip(adorned_data_metadata_packs, reqs):
         queue_manager.publish_response(req, visitor)
 
+    # TODO: pull files by responseFile field from visitor() result
+
     names_of_uploaded_files = lfilter(".out", storage.get_all_object_names(bucket_name))
     uploaded_files = [storage.get_object(bucket_name, fn) for fn in names_of_uploaded_files]
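One possible shape for the TODO in test_serving above (a hypothetical sketch; it assumes the visitor fixture can be applied to a request body directly, which depends on the fixture wiring): now that the storage strategies report where each response was uploaded via the responseFile field, the test could fetch exactly those objects instead of listing the bucket and filtering on ".out":

    response_bodies = [visitor(req) for req in reqs]
    uploaded_files = [
        storage.get_object(bucket_name, response_body["responseFile"])
        for response_body in response_bodies
    ]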
diff --git a/test/unit_tests/queue_visitor_test.py b/test/unit_tests/queue_visitor_test.py
index c781429..5af4c8c 100644
--- a/test/unit_tests/queue_visitor_test.py
+++ b/test/unit_tests/queue_visitor_test.py
@@ -3,6 +3,7 @@ import json
 
 import pytest
 
+from pyinfra.server.packing import bytes_to_string
 from pyinfra.visitor import get_object_descriptor, get_response_object_descriptor
@@ -20,19 +21,23 @@ class TestVisitor:
         storage.clear_bucket(bucket_name)
         storage.put_object(**get_object_descriptor(body), data=gzip.compress(b"content"))
         data_received = visitor.load_data(body)
-        assert b"content" == data_received
+        assert {"data": "content", "metadata": {}} == data_received
 
     @pytest.mark.parametrize("response_strategy_name", ["forwarding", "storage"], scope="session")
     def test_visitor_pulls_and_processes_data(self, visitor, body, storage, bucket_name):
         storage.clear_bucket(bucket_name)
-        storage.put_object(**get_object_descriptor(body), data=gzip.compress("2".encode()))
-        response_body = visitor.load_and_process(body)
-        assert response_body["data"] == "22"
+        storage.put_object(**get_object_descriptor(body), data=gzip.compress(json.dumps(bytes_to_string(b"2")).encode()))
+        response_body = visitor.load_item_from_storage_and_process_with_callback(body)
+        assert response_body["data"] == ["22"]
 
     @pytest.mark.parametrize("response_strategy_name", ["storage"], scope="session")
     def test_visitor_puts_response_on_storage(self, visitor, body, storage, bucket_name):
         storage.clear_bucket(bucket_name)
-        storage.put_object(**get_object_descriptor(body), data=gzip.compress("2".encode()))
+        storage.put_object(
+            **get_object_descriptor(body), data=gzip.compress(json.dumps(bytes_to_string(b"2")).encode())
+        )
         response_body = visitor(body)
         assert "data" not in response_body
-        assert json.loads(gzip.decompress(storage.get_object(**get_response_object_descriptor(body))))["data"] == "22"
+        assert json.loads(
+            gzip.decompress(storage.get_object(bucket_name=bucket_name, object_name=response_body["responseFile"]))
+        )["data"] == ["22"]
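For context on the reworked QueueVisitor.standardize and the new load_data expectation tested above -- a self-contained sketch of the fallback behaviour, with the project's wrap helper stood in by a stub whose shape is inferred from the first assertion in TestVisitor (validation and byte conversion elided):

    import json

    def wrap(data):
        # Hypothetical stand-in: shape inferred from the load_data test.
        return {"data": data, "metadata": {}}

    def standardize(data):
        try:
            decoded = json.loads(data.decode())
        except json.JSONDecodeError:  # case 1 fallback
            # Returning (not merely calling) wrap() is the fix here: the old
            # code fell through and produced None for plain-text payloads.
            return wrap(data.decode())
        if not isinstance(decoded, dict):  # case 1
            return wrap(decoded)
        return decoded  # case 2 (validation and byte decoding elided)

    assert standardize(b"content") == {"data": "content", "metadata": {}}

This is why load_data is now expected to return {"data": "content", "metadata": {}} rather than raw bytes when the stored object is not valid JSON.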