refactoring: move
This commit is contained in:
parent
47f1d77c03
commit
6945760045
95
pyinfra/default_objects.py
Normal file
95
pyinfra/default_objects.py
Normal file
@ -0,0 +1,95 @@
|
||||
import logging
|
||||
from functools import lru_cache
|
||||
|
||||
from funcy import rcompose
|
||||
|
||||
from pyinfra.config import CONFIG
|
||||
from pyinfra.exceptions import AnalysisFailure
|
||||
from pyinfra.queue.consumer import Consumer
|
||||
from pyinfra.queue.queue_manager.pika_queue_manager import PikaQueueManager
|
||||
from pyinfra.server.client_pipeline import ClientPipeline
|
||||
from pyinfra.server.dispatcher.dispatchers.rest import RestDispatcher
|
||||
from pyinfra.server.interpreter.interpreters.rest_callback import RestPickupStreamer
|
||||
from pyinfra.server.packer.packers.rest import RestPacker
|
||||
from pyinfra.server.receiver.receivers.rest import RestReceiver
|
||||
from pyinfra.storage import storages
|
||||
from pyinfra.visitor import StorageStrategy, QueueVisitor
|
||||
|
||||
|
||||
@lru_cache(maxsize=None)
|
||||
def get_consumer():
|
||||
return Consumer(get_visitor(), get_queue_manager())
|
||||
|
||||
|
||||
@lru_cache(maxsize=None)
|
||||
def get_visitor():
|
||||
return QueueVisitor(get_storage(), get_callback(), get_response_strategy())
|
||||
|
||||
|
||||
@lru_cache(maxsize=None)
|
||||
def get_queue_manager():
|
||||
return PikaQueueManager(CONFIG.rabbitmq.queues.input, CONFIG.rabbitmq.queues.output)
|
||||
|
||||
|
||||
@lru_cache(maxsize=None)
|
||||
def get_storage():
|
||||
return storages.get_storage(CONFIG.storage.backend)
|
||||
|
||||
|
||||
@lru_cache(maxsize=None)
|
||||
def get_callback():
|
||||
return make_callback(CONFIG.rabbitmq.callback.analysis_endpoint)
|
||||
|
||||
|
||||
@lru_cache(maxsize=None)
|
||||
def get_response_strategy():
|
||||
return StorageStrategy(get_storage())
|
||||
|
||||
|
||||
@lru_cache(maxsize=None)
|
||||
def make_callback(analysis_endpoint):
|
||||
def callback(data, metadata: dict):
|
||||
try:
|
||||
logging.debug(f"Requesting analysis from {endpoint}...")
|
||||
analysis_response_stream = pipeline([data], [metadata])
|
||||
return analysis_response_stream
|
||||
except Exception as err:
|
||||
logging.warning(f"Exception caught when calling analysis endpoint {endpoint}.")
|
||||
raise AnalysisFailure() from err
|
||||
|
||||
endpoint = f"{analysis_endpoint}/submit"
|
||||
pipeline = get_pipeline(endpoint)
|
||||
|
||||
return callback
|
||||
|
||||
|
||||
@lru_cache(maxsize=None)
|
||||
def get_pipeline(endpoint):
|
||||
return ClientPipeline(
|
||||
RestPacker(), RestDispatcher(endpoint), RestReceiver(), rcompose(RestPickupStreamer(), RestReceiver())
|
||||
)
|
||||
|
||||
|
||||
# def make_callback(analysis_endpoint):
|
||||
# def callback(message: dict):
|
||||
# def perform_operation(operation):
|
||||
# endpoint = f"{analysis_endpoint}/{operation}"
|
||||
# try:
|
||||
# logging.debug(f"Requesting analysis from {endpoint}...")
|
||||
# analysis_response = requests.post(endpoint, data=message["data"])
|
||||
# analysis_response.raise_for_status()
|
||||
# analysis_response = analysis_response.json()
|
||||
# logging.debug(f"Received response.")
|
||||
# return analysis_response
|
||||
# except Exception as err:
|
||||
# logging.warning(f"Exception caught when calling analysis endpoint {endpoint}.")
|
||||
# raise AnalysisFailure() from err
|
||||
#
|
||||
# operations = message.get("operations", ["/"])
|
||||
# results = map(perform_operation, operations)
|
||||
# result = dict(zip(operations, results))
|
||||
# if list(result.keys()) == ["/"]:
|
||||
# result = list(result.values())[0]
|
||||
# return result
|
||||
#
|
||||
# return callback
|
||||
@ -1,3 +1,4 @@
|
||||
import base64
|
||||
from _operator import itemgetter
|
||||
from itertools import starmap
|
||||
from typing import Iterable
|
||||
@ -5,13 +6,16 @@ from typing import Iterable
|
||||
from funcy import compose
|
||||
|
||||
from pyinfra.utils.func import starlift, lift
|
||||
from test.utils.server import bytes_to_string, string_to_bytes
|
||||
|
||||
|
||||
def pack_data_and_metadata_for_rest_transfer(data: Iterable, metadata: Iterable):
|
||||
yield from starmap(pack, zip(data, metadata))
|
||||
|
||||
|
||||
def unpack_fn_pack(fn):
|
||||
return compose(starlift(pack), fn, lift(unpack))
|
||||
|
||||
|
||||
def pack(data: bytes, metadata: dict):
|
||||
package = {"data": bytes_to_string(data), "metadata": metadata}
|
||||
return package
|
||||
@ -22,5 +26,9 @@ def unpack(package):
|
||||
return string_to_bytes(data), metadata
|
||||
|
||||
|
||||
def unpack_fn_pack(fn):
|
||||
return compose(starlift(pack), fn, lift(unpack))
|
||||
def bytes_to_string(data: bytes) -> str:
|
||||
return base64.b64encode(data).decode()
|
||||
|
||||
|
||||
def string_to_bytes(data: str) -> bytes:
|
||||
return base64.b64decode(data.encode())
|
||||
|
||||
55
src/serve.py
55
src/serve.py
@ -1,42 +1,22 @@
|
||||
import logging
|
||||
from multiprocessing import Process
|
||||
|
||||
import requests
|
||||
from retry import retry
|
||||
|
||||
from pyinfra.config import CONFIG
|
||||
from pyinfra.exceptions import AnalysisFailure, ConsumerError
|
||||
from pyinfra.default_objects import get_consumer
|
||||
from pyinfra.exceptions import ConsumerError
|
||||
from pyinfra.flask import run_probing_webserver, set_up_probing_webserver
|
||||
from pyinfra.queue.consumer import Consumer
|
||||
from pyinfra.queue.queue_manager.pika_queue_manager import PikaQueueManager
|
||||
from pyinfra.storage.storages import get_storage
|
||||
from pyinfra.utils.banner import show_banner
|
||||
from pyinfra.visitor import QueueVisitor, StorageStrategy
|
||||
|
||||
|
||||
def make_callback(analysis_endpoint):
|
||||
def callback(message):
|
||||
def perform_operation(operation):
|
||||
endpoint = f"{analysis_endpoint}/{operation}"
|
||||
try:
|
||||
logging.debug(f"Requesting analysis from {endpoint}...")
|
||||
analysis_response = requests.post(endpoint, data=message["data"])
|
||||
analysis_response.raise_for_status()
|
||||
analysis_response = analysis_response.json()
|
||||
logging.debug(f"Received response.")
|
||||
return analysis_response
|
||||
except Exception as err:
|
||||
logging.warning(f"Exception caught when calling analysis endpoint {endpoint}.")
|
||||
raise AnalysisFailure() from err
|
||||
|
||||
operations = message.get("operations", ["/"])
|
||||
results = map(perform_operation, operations)
|
||||
result = dict(zip(operations, results))
|
||||
if list(result.keys()) == ["/"]:
|
||||
result = list(result.values())[0]
|
||||
return result
|
||||
|
||||
return callback
|
||||
@retry(ConsumerError, tries=3, delay=5, jitter=(1, 3))
|
||||
def consume():
|
||||
consumer = get_consumer()
|
||||
try:
|
||||
consumer.basic_consume_and_publish()
|
||||
except Exception as err:
|
||||
raise ConsumerError() from err
|
||||
|
||||
|
||||
def main():
|
||||
@ -46,21 +26,6 @@ def main():
|
||||
logging.info("Starting webserver...")
|
||||
webserver.start()
|
||||
|
||||
callback = make_callback(CONFIG.rabbitmq.callback.analysis_endpoint)
|
||||
storage = get_storage(CONFIG.storage.backend)
|
||||
response_strategy = StorageStrategy(storage)
|
||||
visitor = QueueVisitor(storage, callback, response_strategy)
|
||||
|
||||
queue_manager = PikaQueueManager(CONFIG.rabbitmq.queues.input, CONFIG.rabbitmq.queues.output)
|
||||
|
||||
@retry(ConsumerError, tries=3, delay=5, jitter=(1, 3))
|
||||
def consume():
|
||||
consumer = Consumer(visitor, queue_manager)
|
||||
try:
|
||||
consumer.basic_consume_and_publish()
|
||||
except Exception as err:
|
||||
raise ConsumerError() from err
|
||||
|
||||
try:
|
||||
consume()
|
||||
except KeyboardInterrupt:
|
||||
@ -79,4 +44,4 @@ if __name__ == "__main__":
|
||||
logging.getLogger("flask").setLevel(logging.ERROR)
|
||||
logging.getLogger("urllib3").setLevel(logging.ERROR)
|
||||
|
||||
main()
|
||||
main()
|
||||
@ -2,9 +2,8 @@ import pytest
|
||||
from funcy import compose, lzip
|
||||
|
||||
from pyinfra.server.packer.packers.identity import bundle
|
||||
from pyinfra.server.packing import pack, unpack
|
||||
from pyinfra.server.packing import pack, unpack, bytes_to_string
|
||||
from pyinfra.utils.func import lstarlift
|
||||
from test.utils.server import bytes_to_string
|
||||
|
||||
|
||||
@pytest.mark.parametrize("n_items", [0, 2])
|
||||
|
||||
@ -1,9 +0,0 @@
|
||||
import base64
|
||||
|
||||
|
||||
def bytes_to_string(data: bytes) -> str:
|
||||
return base64.b64encode(data).decode()
|
||||
|
||||
|
||||
def string_to_bytes(data: str) -> bytes:
|
||||
return base64.b64decode(data.encode())
|
||||
Loading…
x
Reference in New Issue
Block a user