diff --git a/pyinfra/utils/buffer.py b/pyinfra/utils/buffer.py index dab68da..4603d34 100644 --- a/pyinfra/utils/buffer.py +++ b/pyinfra/utils/buffer.py @@ -4,7 +4,7 @@ from typing import Any from funcy import repeatedly -logger = logging.getLogger() +logger = logging.getLogger(__name__) def bufferize(fn, buffer_size=3, persist_fn=lambda x: x): @@ -14,8 +14,10 @@ def bufferize(fn, buffer_size=3, persist_fn=lambda x: x): return response_payload def buffer_full(current_buffer_size): + # TODO: this assert does not hold for receiver test, unclear why + # assert current_buffer_size <= buffer_size if current_buffer_size > buffer_size: - logger.warning(f"Overfull buffer: size: {current_buffer_size}; intended capacity: {buffer_size}") + logger.warning(f"Overfull buffer. size: {current_buffer_size}; intended capacity: {buffer_size}") return current_buffer_size == buffer_size def n_items_to_pop(buffer, final): diff --git a/test/unit_tests/buffer_test.py b/test/unit_tests/buffer_test.py new file mode 100644 index 0000000..ad929fc --- /dev/null +++ b/test/unit_tests/buffer_test.py @@ -0,0 +1,19 @@ +from funcy import compose, lmapcat + +from pyinfra.utils.buffer import bufferize + + +def test_buffer(): + def buffer_mean(xs): + return [sum(xs) / len(xs)] if xs else [] + + buffer_mean = bufferize(compose(buffer_mean, list), buffer_size=3) + ys = lmapcat(buffer_mean, range(20)) + assert list(ys) == [1.0, 4.0, 7.0, 10.0, 13.0, 16.0] + + def reverse_buffer(xs): + return reversed(xs) + + reverse_buffer = bufferize(compose(reverse_buffer, list), buffer_size=3) + ys = lmapcat(reverse_buffer, range(10)) + assert ys == [2, 1, 0, 5, 4, 3, 8, 7, 6]