diff --git a/pyinfra/default_objects.py b/pyinfra/default_objects.py index 523fde9..012df1c 100644 --- a/pyinfra/default_objects.py +++ b/pyinfra/default_objects.py @@ -3,7 +3,6 @@ from functools import lru_cache from funcy import rcompose, omit, merge, lmap, project -from pyinfra.config import CONFIG from pyinfra.exceptions import AnalysisFailure from pyinfra.queue.consumer import Consumer from pyinfra.queue.queue_manager.pika_queue_manager import PikaQueueManager @@ -14,64 +13,66 @@ from pyinfra.server.packer.packers.rest import RestPacker from pyinfra.server.receiver.receivers.rest import RestReceiver from pyinfra.storage import storages from pyinfra.visitor import QueueVisitor -from pyinfra.visitor.strategies.response.aggregation import AggregationStorageStrategy -from pyinfra.visitor.response_formatter.formatters.identity import IdentityResponseFormatter from pyinfra.visitor.response_formatter.formatters.default import DefaultResponseFormatter +from pyinfra.visitor.response_formatter.formatters.identity import IdentityResponseFormatter +from pyinfra.visitor.strategies.response.aggregation import AggregationStorageStrategy logger = logging.getLogger(__name__) @lru_cache(maxsize=None) -def get_consumer(callback=None): - callback = callback or get_callback() - return Consumer(get_visitor(callback), get_queue_manager()) +def get_component_factory(config): + return ComponentFactory(config) -@lru_cache(maxsize=None) -def get_visitor(callback): - return QueueVisitor( - storage=get_storage(), - callback=callback, - response_strategy=get_response_strategy(), - response_formatter=get_response_formatter(), - ) +class ComponentFactory: + def __init__(self, config): + self.config = config + @lru_cache(maxsize=None) + def get_consumer(self, callback=None): + callback = callback or self.get_callback() + return Consumer(self.get_visitor(callback), self.get_queue_manager()) -@lru_cache(maxsize=None) -def get_queue_manager(): - return PikaQueueManager(CONFIG.rabbitmq.queues.input, CONFIG.rabbitmq.queues.output) + @lru_cache(maxsize=None) + def get_visitor(self, callback): + return QueueVisitor( + storage=self.get_storage(), + callback=callback, + response_strategy=self.get_response_strategy(), + response_formatter=self.get_response_formatter(), + ) + @lru_cache(maxsize=None) + def get_queue_manager(self): + return PikaQueueManager(self.config.rabbitmq.queues.input, self.config.rabbitmq.queues.output) -@lru_cache(maxsize=None) -def get_storage(): - return storages.get_storage(CONFIG.storage.backend) + @lru_cache(maxsize=None) + def get_storage(self): + return storages.get_storage(self.config.storage.backend) + @lru_cache(maxsize=None) + def get_callback(self, analysis_base_url=None): + analysis_base_url = analysis_base_url or self.config.rabbitmq.callback.analysis_endpoint -@lru_cache(maxsize=None) -def get_callback(analysis_base_url=None): - analysis_base_url = analysis_base_url or CONFIG.rabbitmq.callback.analysis_endpoint + callback = Callback(analysis_base_url) - callback = Callback(analysis_base_url) + def wrapped(body): + body_repr = project(body, ["dossierId", "fileId", "pages", "images", "operation"]) + logger.info(f"Processing {body_repr}...") + return callback(body) - def wrapped(body): - body_repr = project(body, ["dossierId", "fileId", "pages", "images", "operation"]) - logger.info(f"Processing {body_repr}...") - return callback(body) + return wrapped - return wrapped + @lru_cache(maxsize=None) + def get_response_strategy(self, storage=None): + return AggregationStorageStrategy(storage or self.get_storage()) - -@lru_cache(maxsize=None) -def get_response_strategy(storage=None): - return AggregationStorageStrategy(storage or get_storage()) - - -@lru_cache(maxsize=None) -def get_response_formatter(): - return { - "default": DefaultResponseFormatter(), - "identity": IdentityResponseFormatter() - }[CONFIG.service.response_formatter] + @lru_cache(maxsize=None) + def get_response_formatter(self): + return {"default": DefaultResponseFormatter(), "identity": IdentityResponseFormatter()}[ + self.config.service.response_formatter + ] class Callback: diff --git a/src/serve.py b/src/serve.py index 2e52624..9608c73 100644 --- a/src/serve.py +++ b/src/serve.py @@ -4,7 +4,7 @@ from multiprocessing import Process from pyinfra.utils.retry import retry from pyinfra.config import CONFIG -from pyinfra.default_objects import get_consumer +from pyinfra.default_objects import ComponentFactory from pyinfra.exceptions import ConsumerError from pyinfra.flask import run_probing_webserver, set_up_probing_webserver from pyinfra.utils.banner import show_banner @@ -15,7 +15,7 @@ logger = logging.getLogger() @retry(ConsumerError) def consume(): try: - consumer = get_consumer() + consumer = ComponentFactory(CONFIG).get_consumer() consumer.basic_consume_and_publish() except Exception as err: logger.exception(err) diff --git a/test/config.yaml b/test/config.yaml index ba8404c..d88ccf1 100644 --- a/test/config.yaml +++ b/test/config.yaml @@ -24,5 +24,5 @@ webserver: mock_analysis_endpoint: "http://127.0.0.1:5000" -use_docker_fixture: True -logging: False \ No newline at end of file +use_docker_fixture: 0 +logging: 0 \ No newline at end of file diff --git a/test/integration_tests/serve_test.py b/test/integration_tests/serve_test.py index 97f931a..b0de1a0 100644 --- a/test/integration_tests/serve_test.py +++ b/test/integration_tests/serve_test.py @@ -6,13 +6,8 @@ from operator import itemgetter import pytest from funcy import compose, first, second, pluck, lflatten -from pyinfra.default_objects import ( - get_callback, - get_response_strategy, - get_consumer, - get_queue_manager, - get_storage, -) +from pyinfra.config import CONFIG +from pyinfra.default_objects import ComponentFactory, get_component_factory from pyinfra.queue.consumer import Consumer from pyinfra.server.packing import unpack, pack from pyinfra.utils.encoding import compress, decompress @@ -213,10 +208,12 @@ def components_type(request): @pytest.fixture def real_components(url, download_strategy): - callback = get_callback(url) - consumer = get_consumer(callback) - queue_manager = get_queue_manager() - storage = get_storage() + component_factory = get_component_factory(CONFIG) + + callback = component_factory.get_callback(url) + consumer = component_factory.get_consumer(callback) + queue_manager = component_factory.get_queue_manager() + storage = component_factory.get_storage() consumer.visitor.download_strategy = download_strategy return storage, queue_manager, consumer @@ -230,8 +227,13 @@ def download_strategy(many_to_n): @pytest.fixture def test_components(url, queue_manager, storage): - callback = get_callback(url) - visitor = QueueVisitor(storage=storage, callback=callback, response_strategy=get_response_strategy(storage)) + component_factory = ComponentFactory(CONFIG) + + visitor = QueueVisitor( + storage=storage, + callback=component_factory.get_callback(url), + response_strategy=component_factory.get_response_strategy(storage), + ) consumer = Consumer(visitor, queue_manager) return storage, queue_manager, consumer