diff --git a/pyinfra/queue/consumer.py b/pyinfra/queue/consumer.py index 1072178..25c7272 100644 --- a/pyinfra/queue/consumer.py +++ b/pyinfra/queue/consumer.py @@ -2,15 +2,15 @@ from pyinfra.queue.queue_manager.queue_manager import QueueManager class Consumer: - def __init__(self, callback, queue_manager: QueueManager): + def __init__(self, visitor, queue_manager: QueueManager): self.queue_manager = queue_manager - self.callback = callback + self.visitor = visitor - def consume_and_publish(self): - self.queue_manager.consume_and_publish(self.callback) + def consume_and_publish(self, n=None): + self.queue_manager.consume_and_publish(self.visitor, n=n) def basic_consume_and_publish(self): - self.queue_manager.basic_consume_and_publish(self.callback) + self.queue_manager.basic_consume_and_publish(self.visitor) def consume(self, **kwargs): return self.queue_manager.consume(**kwargs) diff --git a/pyinfra/queue/queue_manager/pika_queue_manager.py b/pyinfra/queue/queue_manager/pika_queue_manager.py index 27f82e3..278e446 100644 --- a/pyinfra/queue/queue_manager/pika_queue_manager.py +++ b/pyinfra/queue/queue_manager/pika_queue_manager.py @@ -1,5 +1,6 @@ import json import logging +from itertools import islice import pika @@ -136,11 +137,11 @@ class PikaQueueManager(QueueManager): logger.debug("Consuming") return self.channel.consume(self._input_queue, inactivity_timeout=inactivity_timeout) - def consume_and_publish(self, visitor: QueueVisitor): + def consume_and_publish(self, visitor: QueueVisitor, n=None): logger.info(f"Consuming input queue.") - for message in self.consume(): + for message in islice(self.consume(), n): self.publish_response(message, visitor) def basic_consume_and_publish(self, visitor: QueueVisitor): diff --git a/pyinfra/queue/queue_manager/queue_manager.py b/pyinfra/queue/queue_manager/queue_manager.py index d42bfc5..f568029 100644 --- a/pyinfra/queue/queue_manager/queue_manager.py +++ b/pyinfra/queue/queue_manager/queue_manager.py @@ -43,7 +43,7 @@ class QueueManager(abc.ABC): raise NotImplementedError @abc.abstractmethod - def consume_and_publish(self, callback): + def consume_and_publish(self, callback, n=None): raise NotImplementedError @abc.abstractmethod diff --git a/test/integration_tests/serve_test.py b/test/integration_tests/serve_test.py index ae53d23..40c7578 100644 --- a/test/integration_tests/serve_test.py +++ b/test/integration_tests/serve_test.py @@ -6,7 +6,7 @@ from operator import itemgetter import pytest from frozendict import frozendict -from funcy import lfilter, compose, lzip, pluck, lpluck +from funcy import compose, lzip, lpluck from pyinfra.default_objects import ( get_callback, @@ -26,16 +26,35 @@ def freeze(data, metadata): @pytest.fixture -def components(url, bucket_name, queue_manager, storage): +def mixed_components(url, bucket_name, queue_manager, storage): callback = get_callback(url) - consumer = Consumer(callback, queue_manager) - visitor = QueueVisitor(storage, callback, get_response_strategy(storage)) + consumer = Consumer(visitor, queue_manager) return visitor, queue_manager, storage, consumer +@pytest.fixture +def real_components(url): + callback = get_callback(url) + + +@pytest.fixture +def components(components_type, real_components, mixed_components): + if components_type == "real": + return real_components + elif components_type == "mixed": + return mixed_components + else: + raise ValueError(f"Unknown components type '{components_type}'.") + + +@pytest.fixture(params=["real", "mixed"]) +def components_type(request): + return request.param + + def decode(storage_item): storage_item = json.loads(storage_item.decode()) if not isinstance(storage_item, list): @@ -79,7 +98,7 @@ def decode(storage_item): @pytest.mark.parametrize( "queue_manager_name", [ - "mock", + # "mock", "pika", ], scope="session", @@ -93,6 +112,12 @@ def decode(storage_item): ], scope="session", ) +@pytest.mark.parametrize( + "components_type", + [ + "mixed", + ], +) def test_serving( server_process, input_data_items, @@ -105,9 +130,9 @@ def test_serving( target_data_items, targets, ): - visitor, queue_manager, storage, consumer = components + visitor, _, storage, consumer = components - queue_manager.clear() + consumer.queue_manager.clear() storage.clear_bucket(bucket_name) if storage_item_has_metadata: @@ -123,14 +148,11 @@ def test_serving( for data, message in adorned_data_metadata_packs: storage.put_object(**get_object_descriptor(message), data=gzip.compress(data)) - queue_manager.publish_request(message) + consumer.queue_manager.publish_request(message) - reqs = consumer.consume(inactivity_timeout=5) + consumer.consume_and_publish(n=len(adorned_data_metadata_packs)) - for _, req in zip(adorned_data_metadata_packs, reqs): - queue_manager.publish_response(req, visitor) - - names_of_uploaded_files = lpluck("responseFile", queue_manager.output_queue.to_list()) - uploaded_files = [storage.get_object(bucket_name, fn) for fn in names_of_uploaded_files] + names_of_uploaded_files = lpluck("responseFile", consumer.queue_manager.output_queue.to_list()) + uploaded_files = starmap(storage.get_object, zip(repeat(bucket_name), names_of_uploaded_files)) outputs = sorted(chain(*map(decode, uploaded_files)), key=itemgetter(0)) assert outputs == targets diff --git a/test/queue/queue_manager_mock.py b/test/queue/queue_manager_mock.py index fe50a8f..d904ac1 100644 --- a/test/queue/queue_manager_mock.py +++ b/test/queue/queue_manager_mock.py @@ -1,3 +1,5 @@ +from itertools import islice + from pyinfra.queue.queue_manager.queue_manager import QueueManager, QueueHandle from test.queue.queue_mock import QueueMock @@ -33,8 +35,8 @@ class QueueManagerMock(QueueManager): while self._input_queue: yield self.pull_request() - def consume_and_publish(self, callback): - for message in self.consume(): + def consume_and_publish(self, callback, n=None): + for message in islice(self.consume(), n): self.publish_response(message, callback) def basic_consume_and_publish(self, callback):