diff --git a/pyinfra/server/client_pipeline.py b/pyinfra/server/client_pipeline.py index c53d4e5..7a39dc4 100644 --- a/pyinfra/server/client_pipeline.py +++ b/pyinfra/server/client_pipeline.py @@ -12,4 +12,4 @@ class ClientPipeline: ) def __call__(self, *args, **kwargs): - return self.pipe(*args, **kwargs) + yield from self.pipe(*args, **kwargs) diff --git a/pyinfra/server/processor/processor.py b/pyinfra/server/processor/processor.py index 0c030ad..9dfd89a 100644 --- a/pyinfra/server/processor/processor.py +++ b/pyinfra/server/processor/processor.py @@ -4,6 +4,13 @@ from typing import Union, Any from pyinfra.server.dispatcher.dispatcher import Nothing +def delay(fn, *args, **kwargs): + def inner(): + return fn(*args, **kwargs) + + return inner + + class OnDemandProcessor: def __init__(self, fn): """Function `fn` has to return an iterable and ideally is a generator.""" diff --git a/test/unit_tests/server/pipeline_test.py b/test/unit_tests/server/pipeline_test.py index 58d5c3c..f4cbc88 100644 --- a/test/unit_tests/server/pipeline_test.py +++ b/test/unit_tests/server/pipeline_test.py @@ -5,11 +5,11 @@ from pyinfra.server.dispatcher.dispatchers.forwarding import ForwardingDispatche from pyinfra.server.dispatcher.dispatchers.rest import RestDispatcher from pyinfra.server.interpreter.interpreters.identity import IdentityInterpreter from pyinfra.server.interpreter.interpreters.rest_callback import RestPickupStreamer -from pyinfra.server.packer.packers.identity import IdentityPacker from pyinfra.server.packer.packers.rest import RestPacker from pyinfra.server.client_pipeline import ClientPipeline from pyinfra.server.receiver.receivers.identity import IdentityReceiver from pyinfra.server.receiver.receivers.rest import RestReceiver +from pyinfra.server.server import make_streamable from pyinfra.server.utils import unpack from pyinfra.utils.func import lift @@ -32,12 +32,35 @@ def test_mock_pipeline(): "basic", ], ) +@pytest.mark.parametrize("item_type", ["string"]) def test_pipeline(client_pipeline, input_data_items, metadata, target_data_items): output = client_pipeline(input_data_items, metadata) output = lmap(unpack, output) assert output == target_data_items +@pytest.mark.parametrize("item_type", ["string"]) +@pytest.mark.parametrize("n_items", [1]) +def test_pipeline_is_lazy(input_data_items, metadata): + def lazy_test_fn(*args, **kwargs): + probe["executed"] = True + return b"null", {} + + probe = {"executed": False} + processor_fn = make_streamable(lazy_test_fn, buffer_size=3, batched=False) + + client_pipeline = ClientPipeline( + RestPacker(), ForwardingDispatcher(processor_fn), IdentityReceiver(), IdentityInterpreter() + ) + output = client_pipeline(input_data_items, metadata) + + assert not probe["executed"] + + list(output) + + assert probe["executed"] + + @pytest.fixture def client_pipeline(rest_client_pipeline, basic_client_pipeline, client_pipeline_type): if client_pipeline_type == "rest":