diff --git a/pyinfra/rest.py b/pyinfra/rest.py index b44eedf..698dfbc 100644 --- a/pyinfra/rest.py +++ b/pyinfra/rest.py @@ -1,10 +1,9 @@ -from itertools import chain, tee from operator import itemgetter from typing import Iterable -from funcy import compose, first, rcompose, flatten +from funcy import compose, first, flatten -from pyinfra.utils.func import star, lift, lstarlift, starlift +from pyinfra.utils.func import star, lift, lstarlift from test.utils.server import bytes_to_string, string_to_bytes @@ -23,28 +22,35 @@ def bundle(data: bytes, metadata: dict): return package +def normalize_up(x): + return [x] if isinstance(x, tuple) else x + + +def normalize_down(itr): + + head = first(itr) + + if not head: + return [] + elif isinstance(head, tuple): + return head, *itr + else: + return flatten((head, *itr)) + + def unpack_op_pack(operation): - return compose(inspect("A2"), flatten, inspect("A1"), lstarlift(pack), star(operation), unpack) + return compose(lstarlift(pack), normalize_up, star(operation), unpack) def unpack_batchop_pack(operation): - raise BrokenPipeError - # return rcompose( - # lift(unpack), # unpack the buffer items - # operation, # apply operation on unpacked items - # flatten, # operations may be 1 -> 1, 1 -> n or n -> 1, hence flatten - # lstarlift(pack), - # ) + return compose(lstarlift(pack), normalize_down, operation, lift(unpack)) def inspect(msg="ins"): def inner(x): if isinstance(x, Iterable) and not isinstance(x, dict): - print(11111111111111111111) x = list(x) - else: - print("00000000000") print(msg, x) diff --git a/test/exploration_tests/partial_response_test.py b/test/exploration_tests/partial_response_test.py index 84fa242..c93d9bb 100644 --- a/test/exploration_tests/partial_response_test.py +++ b/test/exploration_tests/partial_response_test.py @@ -6,8 +6,7 @@ from typing import Iterable import pytest import requests -from funcy import rcompose, compose, rpartial, identity, lmap, ilen, first -from more_itertools import flatten +from funcy import rcompose, compose, rpartial, identity, lmap, ilen, first, flatten from pyinfra.rest import pack, unpack, bundle, inspect from pyinfra.utils.func import lift, starlift, parallel_map, star, lstarlift @@ -22,7 +21,9 @@ def dispatch_methods(input_data): def post_partial(url, input_data: Iterable[bytes], metadata): def send(method, data): - return method(url, json=data) + response = method(url, json=data) + response.raise_for_status() + return response pack_data_and_metadata_for_rest_transfer = lift(rpartial(pack, metadata)) dispatch_http_method_left_and_forward_data_right = parallel_map(dispatch_methods, lift(identity)) @@ -34,11 +35,8 @@ def post_partial(url, input_data: Iterable[bytes], metadata): pack_data_and_metadata_for_rest_transfer, dispatch_http_method_left_and_forward_data_right, send_data_with_method_to_analyzer, - inspect("B"), extract_payload_from_responses, - inspect("C"), flatten_buffered_payloads, - inspect("D"), ) return input_data_to_result_data(input_data) @@ -46,14 +44,8 @@ def post_partial(url, input_data: Iterable[bytes], metadata): @pytest.mark.parametrize("item_type", ["string"]) def test_sending_partial_request(url, data_items, metadata, operation, server_process): - op = compose(lstarlift(pack), partial(operation, metadata=metadata)) - expected = list(flatten(map(op, data_items))) - print() - print("exp") - print(expected) + op = compose(star(pack), partial(operation, metadata=metadata)) + expected = lmap(op, data_items) output = list(post_partial(f"{url}/process", data_items, metadata)) - print() - print("out") - print(output) assert output == expected diff --git a/test/fixtures/server.py b/test/fixtures/server.py index bf5940a..f3245c7 100644 --- a/test/fixtures/server.py +++ b/test/fixtures/server.py @@ -58,7 +58,7 @@ def processor_fn(operation, buffer_size, batched): @pytest.fixture def operation(item_type, batched): def upper(string: bytes, metadata): - return [(string.decode().upper().encode(), metadata)] + return string.decode().upper().encode(), metadata def rotate(im: bytes, metadata): im = Image.open(io.BytesIO(im)) @@ -79,7 +79,7 @@ def item_type(request): return request.param -@pytest.fixture(params=[False]) +@pytest.fixture(params=[False, True]) def batched(request): """Controls, whether the buffer processor function of the webserver is applied to batches or single items.""" return request.param