"""Queue-consumer wiring: cached factories plus a REST analysis callback.

Builds the singleton objects (consumer, visitor, queue manager, storage,
response strategy) from ``CONFIG`` and defines ``Callback``, the callable the
queue visitor invokes per message to run a client pipeline against a REST
analysis endpoint.
"""
import logging
from functools import lru_cache

from funcy import lmap, merge, omit, rcompose

from pyinfra.config import CONFIG
from pyinfra.exceptions import AnalysisFailure
from pyinfra.queue.consumer import Consumer
from pyinfra.queue.queue_manager.pika_queue_manager import PikaQueueManager
from pyinfra.server.client_pipeline import ClientPipeline
from pyinfra.server.dispatcher.dispatchers.rest import RestDispatcher
from pyinfra.server.interpreter.interpreters.rest_callback import RestPickupStreamer
from pyinfra.server.packer.packers.rest import RestPacker
from pyinfra.server.receiver.receivers.rest import RestReceiver
from pyinfra.storage import storages
from pyinfra.visitor import AggregationStorageStrategy, QueueVisitor

logger = logging.getLogger(__name__)
logger.setLevel(logging.ERROR)


@lru_cache(maxsize=None)
def get_consumer(callback=None):
    """Return a cached ``Consumer`` wired to the visitor and queue manager.

    Args:
        callback: Optional per-message callable; defaults to the configured
            REST ``Callback`` from :func:`get_callback`.
    """
    callback = callback or get_callback()
    return Consumer(get_visitor(callback), get_queue_manager())


@lru_cache(maxsize=None)
def get_visitor(callback):
    """Return a cached ``QueueVisitor`` for the given callback."""
    return QueueVisitor(get_storage(), callback, get_response_strategy())


@lru_cache(maxsize=None)
def get_queue_manager():
    """Return a cached Pika queue manager over the configured input/output queues."""
    return PikaQueueManager(CONFIG.rabbitmq.queues.input, CONFIG.rabbitmq.queues.output)


@lru_cache(maxsize=None)
def get_storage():
    """Return the cached storage backend selected by ``CONFIG.storage.backend``."""
    return storages.get_storage(CONFIG.storage.backend)


@lru_cache(maxsize=None)
def get_callback(analysis_base_url=None):
    """Return a cached ``Callback`` targeting the given (or configured) base URL."""
    analysis_base_url = analysis_base_url or CONFIG.rabbitmq.callback.analysis_endpoint
    return Callback(analysis_base_url)


@lru_cache(maxsize=None)
def get_response_strategy(storage=None):
    """Return a cached aggregation strategy over the given (or default) storage."""
    return AggregationStorageStrategy(storage or get_storage())


class Callback:
    """Callable that routes queue message bodies to per-endpoint client pipelines."""

    def __init__(self, base_url):
        self.base_url = base_url
        # endpoint URL -> ClientPipeline; instance-local memoization on top of
        # the module-level lru_cache in get_pipeline().
        self.endpoint2pipeline = {}

    def __make_endpoint(self, operation):
        """Build the full endpoint URL for a message's operation."""
        return f"{self.base_url}/{operation}"

    def __get_pipeline(self, endpoint):
        """Return the (memoized) client pipeline for ``endpoint``."""
        if endpoint not in self.endpoint2pipeline:
            self.endpoint2pipeline[endpoint] = get_pipeline(endpoint)
        return self.endpoint2pipeline[endpoint]

    @staticmethod
    def __run_pipeline(pipeline, body):
        """Run ``pipeline`` on a single message ``body`` and return the response list.

        Raises:
            AnalysisFailure: chained from any exception raised by the pipeline.

        TODO: since data and metadata are passed as singletons, there is no
        buffering and hence no batching happening within the pipeline. However,
        the queue acknowledgment logic needs to be changed in order to
        facilitate passing non-singletons, to only ack a message, once a
        response is pulled from the output queue of the pipeline. Probably the
        pipeline return value needs to contains the queue message frame (or
        so), in order for the queue manager to tell which message to ack.

        TODO: casting list (lmap) on `analysis_response_stream` is a temporary
        solution, while the client pipeline operates on singletons
        ([data], [metadata]).
        """

        def combine_storage_item_metadata_with_queue_message_metadata(body):
            # Fold the queue-message-level keys (everything except data/metadata)
            # into the storage item's metadata dict.
            return merge(body["metadata"], omit(body, ["data", "metadata"]))

        def remove_queue_message_metadata(result):
            # Strip the keys we merged in above back out of each response.
            metadata = omit(result["metadata"], queue_message_keys(body))
            return {**result, "metadata": metadata}

        def queue_message_keys(body):
            return {*body.keys()}.difference({"data", "metadata"})

        try:
            data = body["data"]
            metadata = combine_storage_item_metadata_with_queue_message_metadata(body)
            analysis_response_stream = pipeline([data], [metadata])
            analysis_response_stream = lmap(remove_queue_message_metadata, analysis_response_stream)
            return analysis_response_stream
        except Exception as err:
            # logger.exception (not .error) keeps the message AND the traceback.
            logger.exception(err)
            raise AnalysisFailure from err

    def __call__(self, body: dict):
        """Dispatch ``body`` to the endpoint for its operation (default "submit").

        Returns the pipeline's response list, or ``None`` when the analysis
        failed — the failure is deliberately swallowed (best-effort) so the
        consumer can keep processing subsequent messages.
        """
        operation = body.get("operation", "submit")
        endpoint = self.__make_endpoint(operation)
        pipeline = self.__get_pipeline(endpoint)
        try:
            # Use the module logger (not the root logger) so these records
            # follow this module's logging configuration.
            logger.debug(f"Requesting analysis from {endpoint}...")
            return self.__run_pipeline(pipeline, body)
        except AnalysisFailure:
            logger.warning(f"Exception caught when calling analysis endpoint {endpoint}.")


@lru_cache(maxsize=None)
def get_pipeline(endpoint):
    """Return a cached REST client pipeline bound to ``endpoint``."""
    return ClientPipeline(
        RestPacker(),
        RestDispatcher(endpoint),
        RestReceiver(),
        rcompose(RestPickupStreamer(), RestReceiver()),
    )