77 lines
2.3 KiB
Python
77 lines
2.3 KiB
Python
import pytest
|
|
from funcy import repeatedly, takewhile, notnone, lmap, lmapcat, lflatten
|
|
|
|
from pyinfra.server.buffering.stream import FlatStreamBuffer, StreamBuffer
|
|
from pyinfra.server.dispatcher.dispatcher import Nothing
|
|
from pyinfra.server.stream.queued_stream_function import QueuedStreamFunction
|
|
from pyinfra.utils.func import lift, foreach, starlift
|
|
|
|
|
|
@pytest.fixture
def func(one_to_many):
    """Provide the element-wise operation used by the buffering tests.

    The returned callable squares its argument. When ``one_to_many`` is
    truthy it returns the square twice as a 2-tuple ``(y, y)`` instead of
    the bare value ``y``.
    """

    def square(value):
        squared = value**2
        if one_to_many:
            # Emit more than one value per input item.
            return squared, squared
        return squared

    return square
|
|
|
|
|
|
def test_stream_buffer(func, inputs, outputs, buffer_size):
    """StreamBuffer emits every expected output once the stream is terminated.

    All inputs are fed through the buffer item by item, followed by the
    ``Nothing`` sentinel (presumably what flushes items still held in the
    buffer). Feeding a lone ``Nothing`` afterwards must produce no output.
    """
    buffered = StreamBuffer(lift(func), buffer_size=buffer_size)

    terminated_stream = (*inputs, Nothing)
    assert lmapcat(buffered, terminated_stream) == outputs

    # After the terminated stream, a second sentinel must yield nothing further.
    assert lmapcat(buffered, [Nothing]) == []
|
|
|
|
|
|
@pytest.mark.parametrize("n_items", [1])
def test_stream_buffer_catches_type_error(func, inputs, outputs):
    """Consuming a StreamBuffer built from an un-lifted function raises TypeError.

    ``func`` is passed to ``StreamBuffer`` without ``lift``; presumably it then
    receives a batch of items, for which ``x**2`` is undefined. ``outputs`` is
    requested as a fixture but not used by the test body.
    """
    unlifted_buffer = StreamBuffer(func)

    with pytest.raises(TypeError):
        lmapcat(unlifted_buffer, inputs)
|
|
|
|
|
|
def test_flat_stream_buffer(func, inputs, outputs, buffer_size):
    """FlatStreamBuffer maps over a whole stream and yields flattened results.

    Unlike the ``StreamBuffer`` test, the buffer is called once on the full
    input iterable. Its output must equal the flattened expected outputs. An
    empty input must give an empty output.
    """
    flat_buffer = FlatStreamBuffer(lift(func), buffer_size=buffer_size)

    produced = list(flat_buffer(inputs))
    assert produced == lflatten(outputs)

    assert list(flat_buffer([])) == []
|
|
|
|
|
|
@pytest.mark.xfail(reason="input wrong format, TODO: redesign input fixture hierarchy")
def test_flat_stream_buffer_on_different_data(
    core_operation, input_data_items, metadata, target_data_items, buffer_size, item_type, one_to_many
):
    """FlatStreamBuffer applies ``core_operation`` to (data, metadata) pairs.

    This test is currently marked as an expected failure (see the xfail
    reason). Parameter combinations for which no operation is defined
    (``core_operation is Nothing``) are skipped.
    """
    if core_operation is Nothing:
        pytest.skip(f"No operation defined for parameter combination: {item_type=}, {one_to_many=}")

    # starlift: the operation is applied to each (data, metadata) pair.
    flat_buffer = FlatStreamBuffer(starlift(core_operation), buffer_size=buffer_size)

    paired_items = zip(input_data_items, metadata)
    assert list(flat_buffer(paired_items)) == target_data_items

    assert list(flat_buffer([])) == []
|
|
|
|
|
|
def test_queued_stream_function(func, inputs, outputs):
    """Results pushed through a QueuedStreamFunction can be drained in order.

    Every input is pushed first. Results are then popped repeatedly until
    ``pop`` returns ``None``, and the drained sequence must equal ``outputs``.
    """
    queued = QueuedStreamFunction(lift(func))

    foreach(queued.push, inputs)

    drained = list(takewhile(notnone, repeatedly(queued.pop)))
    assert drained == outputs
|
|
|
|
|
|
@pytest.fixture
def inputs(n_items):
    """Return the integers ``0 .. n_items - 1`` as a ``range``."""
    first = 0
    return range(first, n_items)
|
|
|
|
|
|
@pytest.fixture
def outputs(inputs, func):
    """Return the expected results: ``func`` applied to each input, as a list."""
    return [func(item) for item in inputs]
|
|
|
|
|
|
@pytest.fixture(params=[1, 3, 10, 12])
def buffer_size(request):
    """Return the buffer size for the current run; one run per listed value."""
    size = request.param
    return size
|