refactoring: move; added null value param to bufferize

This commit is contained in:
Matthias Bisping 2022-05-09 14:50:54 +02:00
parent c2ed6d78b7
commit 5b913983eb
3 changed files with 32 additions and 51 deletions

View File

@ -1,8 +1,10 @@
import logging
from collections import deque
from itertools import chain
from funcy import repeatedly, identity, flatten
from funcy import flatten, compose, compact
from pyinfra.utils.buffer import bufferize
from pyinfra.utils.func import lift
logger = logging.getLogger(__name__)
@ -15,7 +17,7 @@ class OnDemandProcessor:
def __init__(self, fn):
"""Function `fn` has to return an iterable and ideally is a generator."""
self.execution_queue = chain([])
self.fn = hesitant_bufferize(fn)
self.fn = bufferize(fn)
def submit(self, package, **kwargs) -> None:
self.execution_queue = chain(self.execution_queue, [package])
@ -24,35 +26,6 @@ class OnDemandProcessor:
return next(self.compute())
def compute(self):
yield from flatten(map(self.helper, chain(self.execution_queue, [Nothing])))
items = chain(self.execution_queue, [Nothing])
yield from compose(flatten, compact, lift(self.fn))(items)
yield Nothing
def helper(self, packages):
return self.fn(packages)
def hesitant_bufferize(fn, buffer_size=3, persist_fn=identity):
def buffered_fn(item):
if item is not Nothing:
buffer.append(persist_fn(item))
response_payload = fn(repeatedly(buffer.popleft, n_items_to_pop(buffer, item is Nothing)))
return response_payload
def buffer_full(current_buffer_size):
# TODO: this assert does not hold for receiver test, unclear why
# assert current_buffer_size <= buffer_size
if current_buffer_size > buffer_size:
logger.warning(f"Overfull buffer. size: {current_buffer_size}; intended capacity: {buffer_size}")
return current_buffer_size == buffer_size
def n_items_to_pop(buffer, final):
current_buffer_size = len(buffer)
return (final or buffer_full(current_buffer_size)) * current_buffer_size
buffer = deque()
return buffered_fn

View File

@ -1,22 +1,29 @@
import logging
from collections import deque
from funcy import repeatedly
from funcy import repeatedly, identity
from pyinfra.server.dispatcher.dispatcher import Nothing
logger = logging.getLogger(__name__)
def bufferize(fn, buffer_size=3, persist_fn=lambda x: x):
def buffered_fn(item, final=False):
buffer.append(persist_fn(item))
response_payload = fn(repeatedly(buffer.popleft, n_items_to_pop(buffer, final)))
return response_payload
def bufferize(fn, buffer_size=3, persist_fn=identity, null_value=None):
def buffered_fn(item):
if item is not Nothing:
buffer.append(persist_fn(item))
response_payload = fn(repeatedly(buffer.popleft, n_items_to_pop(buffer, item is Nothing)))
return response_payload or null_value
def buffer_full(current_buffer_size):
# TODO: this assert does not hold for receiver test, unclear why
# assert current_buffer_size <= buffer_size
if current_buffer_size > buffer_size:
logger.warning(f"Overfull buffer. size: {current_buffer_size}; intended capacity: {buffer_size}")
return current_buffer_size == buffer_size
def n_items_to_pop(buffer, final):

View File

@ -1,5 +1,6 @@
from funcy import compose, lmapcat
from funcy import compose, lmapcat, compact, flatten
from pyinfra.server.dispatcher.dispatcher import Nothing
from pyinfra.utils.buffer import bufferize
@ -7,20 +8,20 @@ def test_buffer():
def buffer_mean(xs):
return [sum(xs) / len(xs)] if xs else []
buffer_mean = bufferize(compose(buffer_mean, list), buffer_size=3)
ys = lmapcat(buffer_mean, range(20))
assert list(ys) == [1.0, 4.0, 7.0, 10.0, 13.0, 16.0]
buffer_mean = bufferize(compose(buffer_mean, list), buffer_size=3, null_value=[])
ys = lmapcat(buffer_mean, (*range(20), Nothing))
assert list(ys) == [1.0, 4.0, 7.0, 10.0, 13.0, 16.0, 18.5]
def reverse_buffer(xs):
return reversed(xs)
return reversed(list(xs))
reverse_buffer = bufferize(compose(reverse_buffer, list), buffer_size=3)
ys = lmapcat(reverse_buffer, range(10))
assert ys == [2, 1, 0, 5, 4, 3, 8, 7, 6]
reverse_buffer = bufferize(reverse_buffer, buffer_size=3)
ys = flatten(compact(map(reverse_buffer, (*range(10), Nothing))))
assert list(ys) == [2, 1, 0, 5, 4, 3, 8, 7, 6, 9]
def buffer_sum(xs):
return [sum(xs)]
buffer_sum = bufferize(buffer_sum, buffer_size=3)
ys = lmapcat(buffer_sum, range(10))
assert ys == [0, 0, 3, 0, 0, 12, 0, 0, 21, 0]
buffer_sum = bufferize(buffer_sum, buffer_size=2)
ys = flatten(compact(map(buffer_sum, (*range(10), Nothing))))
assert list(ys) == [0, 1, 0, 5, 0, 9, 0, 13, 0, 17, 0]