diff --git a/pyinfra/default_objects.py b/pyinfra/default_objects.py new file mode 100644 index 0000000..7c09792 --- /dev/null +++ b/pyinfra/default_objects.py @@ -0,0 +1,95 @@ +import logging +from functools import lru_cache + +from funcy import rcompose + +from pyinfra.config import CONFIG +from pyinfra.exceptions import AnalysisFailure +from pyinfra.queue.consumer import Consumer +from pyinfra.queue.queue_manager.pika_queue_manager import PikaQueueManager +from pyinfra.server.client_pipeline import ClientPipeline +from pyinfra.server.dispatcher.dispatchers.rest import RestDispatcher +from pyinfra.server.interpreter.interpreters.rest_callback import RestPickupStreamer +from pyinfra.server.packer.packers.rest import RestPacker +from pyinfra.server.receiver.receivers.rest import RestReceiver +from pyinfra.storage import storages +from pyinfra.visitor import StorageStrategy, QueueVisitor + + +@lru_cache(maxsize=None) +def get_consumer(): + return Consumer(get_visitor(), get_queue_manager()) + + +@lru_cache(maxsize=None) +def get_visitor(): + return QueueVisitor(get_storage(), get_callback(), get_response_strategy()) + + +@lru_cache(maxsize=None) +def get_queue_manager(): + return PikaQueueManager(CONFIG.rabbitmq.queues.input, CONFIG.rabbitmq.queues.output) + + +@lru_cache(maxsize=None) +def get_storage(): + return storages.get_storage(CONFIG.storage.backend) + + +@lru_cache(maxsize=None) +def get_callback(): + return make_callback(CONFIG.rabbitmq.callback.analysis_endpoint) + + +@lru_cache(maxsize=None) +def get_response_strategy(): + return StorageStrategy(get_storage()) + + +@lru_cache(maxsize=None) +def make_callback(analysis_endpoint): + def callback(data, metadata: dict): + try: + logging.debug(f"Requesting analysis from {endpoint}...") + analysis_response_stream = pipeline([data], [metadata]) + return analysis_response_stream + except Exception as err: + logging.warning(f"Exception caught when calling analysis endpoint {endpoint}.") + raise AnalysisFailure() from err + + endpoint = f"{analysis_endpoint}/submit" + pipeline = get_pipeline(endpoint) + + return callback + + +@lru_cache(maxsize=None) +def get_pipeline(endpoint): + return ClientPipeline( + RestPacker(), RestDispatcher(endpoint), RestReceiver(), rcompose(RestPickupStreamer(), RestReceiver()) + ) + + +# def make_callback(analysis_endpoint): +# def callback(message: dict): +# def perform_operation(operation): +# endpoint = f"{analysis_endpoint}/{operation}" +# try: +# logging.debug(f"Requesting analysis from {endpoint}...") +# analysis_response = requests.post(endpoint, data=message["data"]) +# analysis_response.raise_for_status() +# analysis_response = analysis_response.json() +# logging.debug(f"Received response.") +# return analysis_response +# except Exception as err: +# logging.warning(f"Exception caught when calling analysis endpoint {endpoint}.") +# raise AnalysisFailure() from err +# +# operations = message.get("operations", ["/"]) +# results = map(perform_operation, operations) +# result = dict(zip(operations, results)) +# if list(result.keys()) == ["/"]: +# result = list(result.values())[0] +# return result +# +# return callback \ No newline at end of file diff --git a/pyinfra/server/packing.py b/pyinfra/server/packing.py index 32624e9..820b24d 100644 --- a/pyinfra/server/packing.py +++ b/pyinfra/server/packing.py @@ -1,3 +1,4 @@ +import base64 from _operator import itemgetter from itertools import starmap from typing import Iterable @@ -5,13 +6,16 @@ from typing import Iterable from funcy import compose from pyinfra.utils.func import starlift, lift -from test.utils.server import bytes_to_string, string_to_bytes def pack_data_and_metadata_for_rest_transfer(data: Iterable, metadata: Iterable): yield from starmap(pack, zip(data, metadata)) +def unpack_fn_pack(fn): + return compose(starlift(pack), fn, lift(unpack)) + + def pack(data: bytes, metadata: dict): package = {"data": bytes_to_string(data), "metadata": metadata} return package @@ -22,5 +26,9 @@ def unpack(package): return string_to_bytes(data), metadata -def unpack_fn_pack(fn): - return compose(starlift(pack), fn, lift(unpack)) +def bytes_to_string(data: bytes) -> str: + return base64.b64encode(data).decode() + + +def string_to_bytes(data: str) -> bytes: + return base64.b64decode(data.encode()) diff --git a/src/serve.py b/src/serve.py index f670e74..4c69f7b 100644 --- a/src/serve.py +++ b/src/serve.py @@ -1,42 +1,22 @@ import logging from multiprocessing import Process -import requests from retry import retry from pyinfra.config import CONFIG -from pyinfra.exceptions import AnalysisFailure, ConsumerError +from pyinfra.default_objects import get_consumer +from pyinfra.exceptions import ConsumerError from pyinfra.flask import run_probing_webserver, set_up_probing_webserver -from pyinfra.queue.consumer import Consumer -from pyinfra.queue.queue_manager.pika_queue_manager import PikaQueueManager -from pyinfra.storage.storages import get_storage from pyinfra.utils.banner import show_banner -from pyinfra.visitor import QueueVisitor, StorageStrategy -def make_callback(analysis_endpoint): - def callback(message): - def perform_operation(operation): - endpoint = f"{analysis_endpoint}/{operation}" - try: - logging.debug(f"Requesting analysis from {endpoint}...") - analysis_response = requests.post(endpoint, data=message["data"]) - analysis_response.raise_for_status() - analysis_response = analysis_response.json() - logging.debug(f"Received response.") - return analysis_response - except Exception as err: - logging.warning(f"Exception caught when calling analysis endpoint {endpoint}.") - raise AnalysisFailure() from err - - operations = message.get("operations", ["/"]) - results = map(perform_operation, operations) - result = dict(zip(operations, results)) - if list(result.keys()) == ["/"]: - result = list(result.values())[0] - return result - - return callback +@retry(ConsumerError, tries=3, delay=5, jitter=(1, 3)) +def consume(): + consumer = get_consumer() + try: + consumer.basic_consume_and_publish() + except Exception as err: + raise ConsumerError() from err def main(): @@ -46,21 +26,6 @@ def main(): logging.info("Starting webserver...") webserver.start() - callback = make_callback(CONFIG.rabbitmq.callback.analysis_endpoint) - storage = get_storage(CONFIG.storage.backend) - response_strategy = StorageStrategy(storage) - visitor = QueueVisitor(storage, callback, response_strategy) - - queue_manager = PikaQueueManager(CONFIG.rabbitmq.queues.input, CONFIG.rabbitmq.queues.output) - - @retry(ConsumerError, tries=3, delay=5, jitter=(1, 3)) - def consume(): - consumer = Consumer(visitor, queue_manager) - try: - consumer.basic_consume_and_publish() - except Exception as err: - raise ConsumerError() from err - try: consume() except KeyboardInterrupt: @@ -79,4 +44,4 @@ if __name__ == "__main__": logging.getLogger("flask").setLevel(logging.ERROR) logging.getLogger("urllib3").setLevel(logging.ERROR) - main() + main() \ No newline at end of file diff --git a/test/unit_tests/server/utils.py b/test/unit_tests/server/utils.py index a398480..181fadf 100644 --- a/test/unit_tests/server/utils.py +++ b/test/unit_tests/server/utils.py @@ -2,9 +2,8 @@ import pytest from funcy import compose, lzip from pyinfra.server.packer.packers.identity import bundle -from pyinfra.server.packing import pack, unpack +from pyinfra.server.packing import pack, unpack, bytes_to_string from pyinfra.utils.func import lstarlift -from test.utils.server import bytes_to_string @pytest.mark.parametrize("n_items", [0, 2]) diff --git a/test/utils/server.py b/test/utils/server.py deleted file mode 100644 index 4e81c1e..0000000 --- a/test/utils/server.py +++ /dev/null @@ -1,9 +0,0 @@ -import base64 - - -def bytes_to_string(data: bytes) -> str: - return base64.b64encode(data).decode() - - -def string_to_bytes(data: str) -> bytes: - return base64.b64decode(data.encode())