"""Tests for StreamBuffer, FlatStreamBuffer and QueuedStreamFunction.

Fixtures such as ``one_to_many``, ``n_items``, ``core_operation``, ``input_data_items``,
``metadata``, ``target_data_items`` and ``item_type`` are not defined in this module;
presumably they come from a ``conftest.py`` (TODO confirm).
"""
import pytest
from funcy import repeatedly, takewhile, notnone, lmap, lmapcat, lflatten

from pyinfra.server.buffering.stream import FlatStreamBuffer, StreamBuffer
from pyinfra.server.dispatcher.dispatcher import Nothing
from pyinfra.server.stream.queued_stream_function import QueuedStreamFunction
from pyinfra.utils.func import lift, foreach, starlift


@pytest.fixture
def func(one_to_many):
    """Return the squaring function used as the operation under test.

    The returned function maps ``x`` to ``x**2``. When ``one_to_many`` is truthy it
    returns the square twice, as the tuple ``(x**2, x**2)``.
    """

    def fn(x):
        y = x**2
        return y if not one_to_many else (y, y)

    return fn


def test_stream_buffer(func, inputs, outputs, buffer_size):
    """StreamBuffer over a lifted ``func`` emits ``func`` applied to every input.

    After the inputs are followed by ``Nothing``, all results must have been emitted.
    A further call with only ``Nothing`` must then emit nothing.
    """
    stream_buffer = StreamBuffer(lift(func), buffer_size=buffer_size)
    # NOTE(review): ``Nothing`` presumably acts as an end-of-stream marker that
    # flushes a partially filled buffer -- confirm against StreamBuffer.
    assert lmapcat(stream_buffer, (*inputs, Nothing)) == outputs
    assert lmapcat(stream_buffer, [Nothing]) == []


@pytest.mark.parametrize("n_items", [1])
def test_stream_buffer_catches_type_error(func, inputs):
    """A StreamBuffer wrapping the un-lifted ``func`` raises TypeError on the inputs."""
    stream_buffer = StreamBuffer(func)
    with pytest.raises(TypeError):
        lmapcat(stream_buffer, inputs)


def test_flat_stream_buffer(func, inputs, outputs, buffer_size):
    """FlatStreamBuffer yields the flattened expected outputs and nothing for no input."""
    flat_stream_buffer = FlatStreamBuffer(lift(func), buffer_size=buffer_size)
    assert list(flat_stream_buffer(inputs)) == lflatten(outputs)
    assert list(flat_stream_buffer([])) == []


@pytest.mark.xfail(reason="input wrong format, TODO: redesign input fixture hierarchy")
def test_flat_stream_buffer_on_different_data(
    core_operation, input_data_items, metadata, target_data_items, buffer_size, item_type, one_to_many
):
    """FlatStreamBuffer over ``core_operation`` maps (data, metadata) pairs to the targets.

    The test is skipped when no operation exists for the parameter combination, which
    ``core_operation is Nothing`` signals. It is currently marked xfail.
    """
    if core_operation is Nothing:
        pytest.skip(f"No operation defined for parameter combination: {item_type=}, {one_to_many=}")
    flat_stream_buffer = FlatStreamBuffer(starlift(core_operation), buffer_size=buffer_size)
    assert list(flat_stream_buffer(zip(input_data_items, metadata))) == target_data_items
    assert list(flat_stream_buffer([])) == []


def test_queued_stream_function(func, inputs, outputs):
    """Pushed inputs are popped back, in order, as ``func`` outputs.

    Popping repeats until ``pop`` first returns ``None``. This test treats that value
    as meaning the queue is drained.
    """
    queued_stream_function = QueuedStreamFunction(lift(func))
    foreach(queued_stream_function.push, inputs)
    assert list(takewhile(notnone, repeatedly(queued_stream_function.pop))) == outputs


@pytest.fixture
def inputs(n_items):
    """Integers ``0 .. n_items - 1`` as a ``range``."""
    return range(n_items)


@pytest.fixture
def outputs(inputs, func):
    """Expected results: ``func`` applied to each input, collected into a list."""
    return lmap(func, inputs)


@pytest.fixture(params=[1, 3, 10, 12])
def buffer_size(request):
    """Buffer sizes over which the buffering tests are parametrized."""
    return request.param