pyinfra/pyinfra/default_objects.py
2022-06-01 16:00:46 +02:00

123 lines
4.4 KiB
Python

import logging
from functools import lru_cache
from funcy import rcompose, omit, merge, lmap
from pyinfra.config import CONFIG
from pyinfra.exceptions import AnalysisFailure
from pyinfra.queue.consumer import Consumer
from pyinfra.queue.queue_manager.pika_queue_manager import PikaQueueManager
from pyinfra.server.client_pipeline import ClientPipeline
from pyinfra.server.dispatcher.dispatchers.rest import RestDispatcher
from pyinfra.server.interpreter.interpreters.rest_callback import RestPickupStreamer
from pyinfra.server.packer.packers.rest import RestPacker
from pyinfra.server.receiver.receivers.rest import RestReceiver
from pyinfra.storage import storages
from pyinfra.visitor import QueueVisitor, AggregationStorageStrategy
# Module-level logger; its level is pinned to ERROR at import time, so DEBUG/INFO/WARNING
# records sent through `logger` are dropped unless the level is changed later.
logger = logging.getLogger(name=__name__)
logger.setLevel(level=logging.ERROR)
@lru_cache(maxsize=None)
def get_consumer(callback=None):
    """Build the queue ``Consumer``, memoized per ``callback`` argument.

    A falsy ``callback`` is replaced by the default ``get_callback()``. The consumer is
    wired to the visitor for that callback and to the shared queue manager.
    """
    if not callback:
        callback = get_callback()
    visitor = get_visitor(callback)
    return Consumer(visitor, get_queue_manager())
@lru_cache(maxsize=None)
def get_visitor(callback):
    """Return a memoized ``QueueVisitor`` combining storage, ``callback`` and the response strategy."""
    storage = get_storage()
    response_strategy = get_response_strategy()
    return QueueVisitor(storage, callback, response_strategy)
@lru_cache(maxsize=None)
def get_queue_manager():
    """Return the shared ``PikaQueueManager`` for the configured input and output queues."""
    queues = CONFIG.rabbitmq.queues
    return PikaQueueManager(queues.input, queues.output)
@lru_cache(maxsize=None)
def get_storage():
    """Return the shared storage object for the backend named in the configuration."""
    backend = CONFIG.storage.backend
    return storages.get_storage(backend)
@lru_cache(maxsize=None)
def get_callback(analysis_base_url=None):
    """Return a memoized ``Callback`` for ``analysis_base_url``.

    A falsy URL falls back to ``CONFIG.rabbitmq.callback.analysis_endpoint``.
    """
    if not analysis_base_url:
        analysis_base_url = CONFIG.rabbitmq.callback.analysis_endpoint
    return Callback(analysis_base_url)
@lru_cache(maxsize=None)
def get_response_strategy(storage=None):
    """Return a memoized ``AggregationStorageStrategy``; a falsy ``storage`` defaults to ``get_storage()``."""
    if not storage:
        storage = get_storage()
    return AggregationStorageStrategy(storage)
class Callback:
    """Forward queue-message bodies to a REST analysis service through a client pipeline.

    The target endpoint is ``f"{base_url}/{operation}"``, where ``operation`` comes from the
    message body and defaults to ``"submit"``. One pipeline is memoized per endpoint.
    """

    def __init__(self, base_url):
        self.base_url = base_url
        # endpoint -> pipeline memo (get_pipeline is additionally lru-cached at module level).
        self.endpoint2pipeline = {}

    def __make_endpoint(self, operation):
        return f"{self.base_url}/{operation}"

    def __get_pipeline(self, endpoint):
        if endpoint in self.endpoint2pipeline:
            pipeline = self.endpoint2pipeline[endpoint]
        else:
            pipeline = get_pipeline(endpoint)
            self.endpoint2pipeline[endpoint] = pipeline
        return pipeline

    @staticmethod
    def __run_pipeline(pipeline, body):
        """Run one message body through ``pipeline`` and return the list of results.

        The pipeline receives ``[body["data"]]`` and ``[metadata]``, where ``metadata`` is the
        stored item metadata (``body["metadata"]``) merged with every other top-level key of
        the message. Those extra message keys are stripped again from each result's metadata.

        Raises:
            AnalysisFailure: wrapping any exception raised while preparing the input, running
                the pipeline or post-processing its results.

        TODO: since data and metadata are passed as singletons, there is no buffering and hence no batching happening
        within the pipeline. However, the queue acknowledgment logic needs to be changed in order to facilitate
        passing non-singletons, to only ack a message, once a response is pulled from the output queue of the
        pipeline. Probably the pipeline return value needs to contain the queue message frame (or so), in order for
        the queue manager to tell which message to ack.
        TODO: casting list (lmap) on `analysis_response_stream` is a temporary solution, while the client pipeline
        operates on singletons ([data], [metadata]).
        """

        def combine_storage_item_metadata_with_queue_message_metadata(body):
            # Message-level keys are merged after the stored metadata, so they win on key clashes.
            return merge(body["metadata"], omit(body, ["data", "metadata"]))

        def queue_message_keys(body):
            return {*body.keys()}.difference({"data", "metadata"})

        try:
            data = body["data"]
            metadata = combine_storage_item_metadata_with_queue_message_metadata(body)
            # The key set depends only on `body`, so compute it once rather than per result.
            message_keys = queue_message_keys(body)

            def remove_queue_message_metadata(result):
                return {**result, "metadata": omit(result["metadata"], message_keys)}

            analysis_response_stream = pipeline([data], [metadata])
            # lmap forces the (possibly lazy) stream inside the try, so its errors are wrapped too.
            analysis_response_stream = lmap(remove_queue_message_metadata, analysis_response_stream)
            return analysis_response_stream
        except Exception as err:
            # logger.exception keeps the traceback; the caller only logs a generic warning.
            logger.exception(err)
            raise AnalysisFailure from err

    def __call__(self, body: dict):
        """Request analysis for ``body``.

        Returns the list produced by the pipeline, or ``None`` when the analysis failed.
        Failures are logged rather than propagated.
        """
        operation = body.get("operation", "submit")
        endpoint = self.__make_endpoint(operation)
        pipeline = self.__get_pipeline(endpoint)
        try:
            # NOTE(review): these calls use the root logger, not the module `logger`; switching
            # them would change which messages are emitted given that logger's configured level.
            logging.debug("Requesting analysis from %s...", endpoint)
            return self.__run_pipeline(pipeline, body)
        except AnalysisFailure:
            logging.warning("Exception caught when calling analysis endpoint %s.", endpoint)
            return None
@lru_cache(maxsize=None)
def get_pipeline(endpoint):
    """Build the REST ``ClientPipeline`` for ``endpoint`` (memoized per endpoint).

    The stages are a packer, a dispatcher bound to ``endpoint``, a receiver, and a final
    stage ``rcompose(RestPickupStreamer(), RestReceiver())``.
    """
    packer = RestPacker()
    dispatcher = RestDispatcher(endpoint)
    receiver = RestReceiver()
    pickup = rcompose(RestPickupStreamer(), RestReceiver())
    return ClientPipeline(packer, dispatcher, receiver, pickup)