diff --git a/pyinfra/server/bufferizer/lazy_bufferizer.py b/pyinfra/server/bufferizer/lazy_bufferizer.py index ad40312..d394cca 100644 --- a/pyinfra/server/bufferizer/lazy_bufferizer.py +++ b/pyinfra/server/bufferizer/lazy_bufferizer.py @@ -54,13 +54,13 @@ class StreamBuffer: yield from takewhile(is_not_nothing, repeatedly(self.pop)) def push(self, item): + self.result_stream = chain(self.result_stream, self.compute(item)) + + def compute(self, item): try: - self.result_stream = chain(self.result_stream, self.compute(item)) + yield from self.fn(item) except TypeError as err: raise TypeError("Function failed with type error. Is it mappable?") from err - def compute(self, item): - yield from self.fn(item) - def pop(self): return first(chain(self.result_stream, [Nothing])) diff --git a/pyinfra/utils/func.py b/pyinfra/utils/func.py index fed8161..cf5e308 100644 --- a/pyinfra/utils/func.py +++ b/pyinfra/utils/func.py @@ -31,6 +31,11 @@ def duplicate_stream_and_apply(f1, f2): return compose(star(parallel(f1, f2)), tee) +def foreach(fn, iterable): + for itm in iterable: + fn(itm) + + def parallel_map(f1, f2): """Applies functions to a stream in parallel and yields a stream of tuples: parallel_map :: a -> b, a -> c -> [a] -> [(b, c)] diff --git a/test/unit_tests/server/stream_buffer_test.py b/test/unit_tests/server/stream_buffer_test.py index 966e3c9..fb43386 100644 --- a/test/unit_tests/server/stream_buffer_test.py +++ b/test/unit_tests/server/stream_buffer_test.py @@ -1,17 +1,62 @@ -from itertools import chain - import pytest +from funcy import repeatedly, takewhile, notnone, lmap, lmapcat -from pyinfra.server.bufferizer.lazy_bufferizer import StreamBuffer +from pyinfra.server.bufferizer.lazy_bufferizer import FlatStreamBuffer, StreamBuffer from pyinfra.server.dispatcher.dispatcher import Nothing -from pyinfra.utils.func import lift +from pyinfra.server.server import LazyProcessor +from pyinfra.utils.func import lift, foreach -@pytest.mark.parametrize("buffer_size", [0, 1, 3, 10, 12]) -def test_stream_buffer(buffer_size): - def func(x): +@pytest.fixture +def func(): + def fn(x): return x ** 2 - func = StreamBuffer(lift(func)) + return fn - assert list(chain(*map(func, [*range(10), Nothing]))) == [0, 1, 4, 9, 16, 25, 36, 49, 64, 81] + +def test_stream_buffer(func, inputs, outputs, buffer_size): + stream_buffer = StreamBuffer(lift(func), buffer_size=buffer_size) + + assert lmapcat(stream_buffer, (*inputs, Nothing)) == outputs + assert lmapcat(stream_buffer, [Nothing]) == [] + + +@pytest.mark.parametrize("n_items", [1]) +def test_stream_buffer_catches_type_error(func, inputs, outputs): + + stream_buffer = StreamBuffer(func) + + with pytest.raises(TypeError): + lmapcat(stream_buffer, inputs) + + +def test_flat_stream_buffer(func, inputs, outputs, buffer_size): + flat_stream_buffer = FlatStreamBuffer(lift(func), buffer_size=buffer_size) + + assert list(flat_stream_buffer(inputs)) == outputs + assert list(flat_stream_buffer([])) == [] + + +def test_lazy_processor(func, inputs, outputs): + stream_buffer = FlatStreamBuffer(lift(func)) + lazy_processor = LazyProcessor(stream_buffer) + + foreach(lazy_processor.push, inputs) + + assert list(takewhile(notnone, repeatedly(lazy_processor.pop))) == outputs + + +@pytest.fixture +def inputs(n_items): + return range(n_items) + + +@pytest.fixture +def outputs(inputs, func): + return lmap(func, inputs) + + +@pytest.fixture(params=[0, 1, 3, 10, 12]) +def buffer_size(request): + return request.param