diff --git a/pyinfra/server/dispatcher/dispatchers/forwarding.py b/pyinfra/server/dispatcher/dispatchers/forwarding.py new file mode 100644 index 0000000..05328ba --- /dev/null +++ b/pyinfra/server/dispatcher/dispatchers/forwarding.py @@ -0,0 +1,12 @@ +from pyinfra.server.dispatcher.dispatcher import Dispatcher + + +class ForwardingDispatcher(Dispatcher): + def __init__(self, fn): + self.fn = fn + + def patch(self, package): + return self.fn(package, final=False) + + def post(self, package): + return self.fn(package, final=True) diff --git a/pyinfra/server/receiver/receivers/identity.py b/pyinfra/server/receiver/receivers/identity.py new file mode 100644 index 0000000..5abb22c --- /dev/null +++ b/pyinfra/server/receiver/receivers/identity.py @@ -0,0 +1,9 @@ +from typing import Iterable + +from pyinfra.server.receiver.receiver import Receiver + + +class IdentityReceiver(Receiver): + def __call__(self, responses: Iterable): + for response in responses: + yield response diff --git a/pyinfra/server/server.py b/pyinfra/server/server.py index e46d611..a5db6c5 100644 --- a/pyinfra/server/server.py +++ b/pyinfra/server/server.py @@ -12,7 +12,7 @@ from pyinfra.utils.func import starlift, star, lift logger = logging.getLogger() -class ServerPipeline: +class StreamProcessor: def __init__(self, fn): """Function `fn` has to return an iterable and ideally is a generator.""" self.pipe = rcompose( @@ -29,7 +29,7 @@ class ServerPipeline: def make_streamable(operation, buffer_size, batched): operation = operation if batched else starlift(operation) - operation = ServerPipeline(operation) + operation = StreamProcessor(operation) operation = BufferedProcessor(operation, buffer_size=buffer_size) operation = compose(flatten, operation) diff --git a/test/fixtures/input.py b/test/fixtures/input.py index 4795d63..ecab88f 100644 --- a/test/fixtures/input.py +++ b/test/fixtures/input.py @@ -42,6 +42,11 @@ def endpoint(url): return f"{url}/submit" +@pytest.fixture(params=["rest", "basic"]) +def client_pipeline_type(request): + return request.param + + @pytest.fixture(params=[1, 0, 5]) def n_items(request): return request.param diff --git a/test/unit_tests/server/pipeline_test.py b/test/unit_tests/server/pipeline_test.py index a07a1c1..58d5c3c 100644 --- a/test/unit_tests/server/pipeline_test.py +++ b/test/unit_tests/server/pipeline_test.py @@ -1,10 +1,14 @@ import pytest from funcy import rcompose, lmap +from pyinfra.server.dispatcher.dispatchers.forwarding import ForwardingDispatcher from pyinfra.server.dispatcher.dispatchers.rest import RestDispatcher +from pyinfra.server.interpreter.interpreters.identity import IdentityInterpreter from pyinfra.server.interpreter.interpreters.rest_callback import RestPickupStreamer +from pyinfra.server.packer.packers.identity import IdentityPacker from pyinfra.server.packer.packers.rest import RestPacker from pyinfra.server.client_pipeline import ClientPipeline +from pyinfra.server.receiver.receivers.identity import IdentityReceiver from pyinfra.server.receiver.receivers.rest import RestReceiver from pyinfra.server.utils import unpack from pyinfra.utils.func import lift @@ -21,6 +25,13 @@ def test_mock_pipeline(): assert list(pipeline(data)) == list(rcompose(f, g, h, u)(data)) +@pytest.mark.parametrize( + "client_pipeline_type", + [ + "rest", + "basic", + ], +) def test_pipeline(client_pipeline, input_data_items, metadata, target_data_items): output = client_pipeline(input_data_items, metadata) output = lmap(unpack, output) @@ -28,15 +39,24 @@ def test_pipeline(client_pipeline, input_data_items, metadata, target_data_items @pytest.fixture -def client_pipeline(rest_client_pipeline): - return rest_client_pipeline +def client_pipeline(rest_client_pipeline, basic_client_pipeline, client_pipeline_type): + if client_pipeline_type == "rest": + return rest_client_pipeline + elif client_pipeline_type == "basic": + return basic_client_pipeline @pytest.fixture def rest_client_pipeline(server_process, endpoint, rest_interpreter): + """Requires a webserver to listen on `endpoint`""" return ClientPipeline(RestPacker(), RestDispatcher(endpoint), RestReceiver(), rest_interpreter) +@pytest.fixture +def basic_client_pipeline(endpoint, rest_interpreter, processor_fn): + return ClientPipeline(RestPacker(), ForwardingDispatcher(processor_fn), IdentityReceiver(), IdentityInterpreter()) + + @pytest.fixture def rest_interpreter(): return rcompose(RestPickupStreamer(), RestReceiver())