diff --git a/pyinfra/visitor.py b/pyinfra/visitor.py
index d235c11..5eccb94 100644
--- a/pyinfra/visitor.py
+++ b/pyinfra/visitor.py
@@ -2,16 +2,16 @@
 import abc
 import gzip
 import json
 import logging
+import random
 
 from collections import deque
 from operator import itemgetter
-import random
-from typing import Callable, Generator
+from typing import Callable
 
-from funcy import omit, pluck, first, lmap
+from funcy import omit, lmap
 
 from pyinfra.config import CONFIG, parse_disjunction_string
 from pyinfra.exceptions import DataLoadingFailure
-from pyinfra.server.packing import string_to_bytes
+from pyinfra.server.packing import string_to_bytes, bytes_to_string
 from pyinfra.storage.storage import Storage
@@ -34,7 +34,8 @@ def get_object_descriptor(body):
 def get_response_object_descriptor(body):
     return {
         "bucket_name": parse_disjunction_string(CONFIG.storage.bucket),
-        "object_name": get_response_object_name(body)+str(random.randint(0, 100)), # TODO: this random suffix should be built by some policy
+        "object_name": get_response_object_name(body)
+        + str(random.randint(0, 100)), # TODO: this random suffix should be built by some policy
     }
 
 
@@ -81,8 +82,14 @@ class IdentifierDispatchCallback(DispatchCallback):
 
         return identifier != self.identifier
 
-    def __call__(self, payload):
-        return self.has_new_identifier(payload)
+    # def data_is_non_empty(self, data):
+    #
+    #     if isinstance(data, str):
+    #         self.put_object(data, metadata)
+
+    def __call__(self, metadata):
+
+        return self.has_new_identifier(metadata)
 
 
 class AggregationStorageStrategy(ResponseStrategy):
@@ -116,11 +123,15 @@ class AggregationStorageStrategy(ResponseStrategy):
             self.upload_queue_items(metadata)
 
     def handle_response(self, payload, final=False):
-        metadata = omit(payload, ["data"])
-        data = payload["data"]
-        for item in data:
-            self.upload_or_aggregate(item, metadata)
-        return metadata
+        request_metadata = omit(payload, ["result_data"])
+        result_data = payload["result_data"]
+        for item in result_data:
+            self.upload_or_aggregate(item, request_metadata)
+        return request_metadata
+
+
+class InvalidStorageItemFormat(ValueError):
+    pass
 
 
 class QueueVisitor:
@@ -129,29 +140,63 @@ class QueueVisitor:
         self.callback = callback
         self.response_strategy = response_strategy
 
-    def load_data(self, body):
-        def download():
-            logging.debug(f"Downloading {object_descriptor}...")
-            data = self.storage.get_object(**object_descriptor)
-            logging.debug(f"Downloaded {object_descriptor}.")
-            return data
-
-        object_descriptor = get_object_descriptor(body)
-
+    def download(self, object_descriptor):
         try:
-            return gzip.decompress(download())
+            data = self.storage.get_object(**object_descriptor)
         except Exception as err:
             logging.warning(f"Loading data from storage failed for {object_descriptor}.")
             raise DataLoadingFailure() from err
 
-    def process_data(self, data, body):
-        return self.callback({"data": data, "metadata": body})
+        return data
+
+    @staticmethod
+    def standardize(data: bytes):
+        """Storage items can be a blob or a blob with metadata. Standardizes to the latter.
+
+        Cases:
+        1) backend upload: data as bytes
+        2) some Python service's upload: data as bytes of a JSON string "{'data': <string>, 'metadata': <dict>}",
+           where the value of key 'data' was encoded with bytes_to_string(...)
+
+        TODO:
+            This is really kinda wonky.
+        """
+
+        def validate(data):
+            if not ("data" in data and "metadata" in data):
+                raise InvalidStorageItemFormat(f"Expected a mapping with keys 'data' and 'metadata', got {data}.")
+
+        def wrap(data):
+            return {"data": data, "metadata": {}}
+
+        try:
+            data = json.loads(data.decode())
+            if not isinstance(data, dict):  # case 1
+                return wrap(string_to_bytes(data))
+            else:  # case 2
+                validate(data)
+                data["data"] = string_to_bytes(data["data"])
+                return data
+        except json.JSONDecodeError:  # case 1 fallback
+            return wrap(data.decode())
+
+    def load_data(self, body):
+        object_descriptor = get_object_descriptor(body)
+        logging.debug(f"Downloading {object_descriptor}...")
+        data = self.download(object_descriptor)
+        logging.debug(f"Downloaded {object_descriptor}.")
+        data = gzip.decompress(data)
+        data = self.standardize(data)
+        return data
+
+    def process_data(self, data_metadata_pack):
+        return self.callback(data_metadata_pack)
 
     def load_and_process(self, body):
         data_from_storage = self.load_data(body)
-        result = self.process_data(data_from_storage, body)
+        result = self.process_data(data_from_storage)
         result = lmap(json.dumps, result)
-        result_body = {"data": result, **body}
+        result_body = {"result_data": result, **body}
         return result_body
 
     def __call__(self, body):
diff --git a/test/integration_tests/serve_test.py b/test/integration_tests/serve_test.py
index 697fbef..4450d06 100644
--- a/test/integration_tests/serve_test.py
+++ b/test/integration_tests/serve_test.py
@@ -1,34 +1,51 @@
 import gzip
 import json
 import logging
+from itertools import starmap
 
 import pytest
-from funcy import notnone, filter, lfilter
+from funcy import notnone, filter, lfilter, lmap, compose
 
 from pyinfra.default_objects import get_visitor, get_queue_manager, get_storage, get_consumer, get_callback
 from pyinfra.server.dispatcher.dispatcher import Nothing
-from pyinfra.server.packing import string_to_bytes
+from pyinfra.server.packer.packers.identity import bundle
+from pyinfra.server.packing import string_to_bytes, bytes_to_string, unpack, pack
 from pyinfra.visitor import get_object_descriptor
 from test.utils.input import adorn_data_with_storage_info
 
 logger = logging.getLogger(__name__)
 
 
-@pytest.mark.parametrize("one_to_many", [True])
+@pytest.mark.parametrize("one_to_many", [False, True])
 @pytest.mark.parametrize("analysis_task", [False])
-@pytest.mark.parametrize("n_items", [1])
+@pytest.mark.parametrize("n_items", [2])
 @pytest.mark.parametrize("n_pages", [1])
 @pytest.mark.parametrize("buffer_size", [2])
+@pytest.mark.parametrize(
+    "storage_item_has_metadata",
+    [
+        True,
+        False
+    ],
+)
 @pytest.mark.parametrize(
     "item_type",
     [
-        # "string",
+        "string",
        "image",
     ],
 )
-def test_serving(server_process, input_data_items, bucket_name, endpoint, core_operation):
+def test_serving(
+    server_process,
+    input_data_items,
+    metadata,
+    bucket_name,
+    endpoint,
+    core_operation,
+    targets,
+    storage_item_has_metadata,
+):
     print()
-    print(core_operation.__name__)
 
     callback = get_callback(endpoint)
     visitor = get_visitor(callback)
@@ -39,17 +56,20 @@ def test_serving(server_process, input_data_items, bucket_name, endpoint, core_o
     queue_manager.clear()
     storage.clear_bucket(bucket_name)
 
-    items = adorn_data_with_storage_info(input_data_items)
+    if storage_item_has_metadata:
+        data_metadata_packs = starmap(compose(lambda s: s.encode(), json.dumps, pack), zip(input_data_items, metadata))
+    else:
+        data_metadata_packs = map(compose(lambda s: s.encode(), json.dumps, bytes_to_string), input_data_items)
 
-    for data, message in items:
+    adorned_data_metadata_packs = adorn_data_with_storage_info(data_metadata_packs)
+
+    for data, message in adorned_data_metadata_packs:
         storage.put_object(**get_object_descriptor(message), data=gzip.compress(data))
         queue_manager.publish_request(message)
 
     reqs = consumer.consume(inactivity_timeout=5)
 
-    for itm, req in zip(items, reqs):
-        logger.debug(f"Processing item {itm}")
-        print(f"Processing item")
+    for itm, req in zip(adorned_data_metadata_packs, reqs):
         queue_manager.publish_response(req, visitor)
 
     def decode(storage_item):
@@ -59,12 +79,13 @@
        except json.decoder.JSONDecodeError:
            return None
 
-
-    print(list(storage.get_all_object_names(bucket_name)))
+    # print(list(storage.get_all_object_names(bucket_name)))
     names_of_uploaded_files = lfilter(".out", storage.get_all_object_names(bucket_name))
     uploaded_files = [storage.get_object(bucket_name, fn) for fn in names_of_uploaded_files]
-    print(names_of_uploaded_files)
+    # print(names_of_uploaded_files)
 
     for storage_item in [*map(decode, uploaded_files)]:
         storage_item["data"] = string_to_bytes(storage_item["data"])
-        print(storage_item)
+        print("si", storage_item)
+
+        # print(targets)
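
For illustration, a minimal standalone sketch of the two storage-item formats that QueueVisitor.standardize normalizes. The encoder is an assumption: base64 stands in for pyinfra.server.packing's bytes_to_string/string_to_bytes, whose real encoding is not shown in this diff, and the key validation is omitted.

import base64
import json

# Assumed stand-ins for the pyinfra.server.packing helpers (base64 is only a guess for this sketch).
def bytes_to_string(data: bytes) -> str:
    return base64.b64encode(data).decode()

def string_to_bytes(data: str) -> bytes:
    return base64.b64decode(data)

def standardize(raw: bytes) -> dict:
    # Mirrors QueueVisitor.standardize: normalize both formats to {"data": ..., "metadata": ...}.
    def wrap(data):
        return {"data": data, "metadata": {}}

    try:
        decoded = json.loads(raw.decode())
    except json.JSONDecodeError:  # case 1: a plain blob that is not JSON
        return wrap(raw.decode())
    if not isinstance(decoded, dict):  # case 1: a JSON-encoded string
        return wrap(string_to_bytes(decoded))
    decoded["data"] = string_to_bytes(decoded["data"])  # case 2: blob with metadata
    return decoded

print(standardize(b"raw backend payload"))  # case 1
packed = json.dumps({"data": bytes_to_string(b"payload"), "metadata": {"id": 1}}).encode()
print(standardize(packed))  # case 2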
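Similarly, a rough sketch of how the test now builds the two storage-item variants before gzip-compressing and uploading them. Here pack and bytes_to_string are hypothetical local stand-ins for the pyinfra.server.packing helpers (their real behavior is not shown in the diff); funcy.compose applies its functions right to left, so the rightmost function runs first.

import base64
import json
from itertools import starmap
from funcy import compose

def bytes_to_string(data: bytes) -> str:  # assumed stand-in
    return base64.b64encode(data).decode()

def pack(data: bytes, metadata: dict) -> dict:  # assumed stand-in
    return {"data": bytes_to_string(data), "metadata": metadata}

input_data_items = [b"item-0", b"item-1"]
metadata = [{"id": 0}, {"id": 1}]

# storage_item_has_metadata=True: pack -> json.dumps -> encode
with_metadata = starmap(compose(lambda s: s.encode(), json.dumps, pack), zip(input_data_items, metadata))
# storage_item_has_metadata=False: a bare JSON string of the encoded bytes
without_metadata = map(compose(lambda s: s.encode(), json.dumps, bytes_to_string), input_data_items)

print(list(with_metadata))
print(list(without_metadata))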