import logging
|
|
from functools import lru_cache
|
|
|
|
from funcy import rcompose, omit, merge, lmap
|
|
|
|
from pyinfra.config import CONFIG
|
|
from pyinfra.exceptions import AnalysisFailure
|
|
from pyinfra.queue.consumer import Consumer
|
|
from pyinfra.queue.queue_manager.pika_queue_manager import PikaQueueManager
|
|
from pyinfra.server.client_pipeline import ClientPipeline
|
|
from pyinfra.server.dispatcher.dispatchers.rest import RestDispatcher
|
|
from pyinfra.server.interpreter.interpreters.rest_callback import RestPickupStreamer
|
|
from pyinfra.server.packer.packers.rest import RestPacker
|
|
from pyinfra.server.receiver.receivers.rest import RestReceiver
|
|
from pyinfra.storage import storages
|
|
from pyinfra.visitor import QueueVisitor, AggregationStorageStrategy
|
|
|
|
# Module-level logger, raised to ERROR so only pipeline failures are emitted
# (see the error logging inside Callback.__run_pipeline below).
logger = logging.getLogger(__name__)
logger.setLevel(logging.ERROR)
|
|
|
|
|
|
@lru_cache(maxsize=None)
def get_consumer(callback=None):
    """Return the (cached) queue consumer.

    When *callback* is falsy, the default analysis callback from
    :func:`get_callback` is used.
    """
    if not callback:
        callback = get_callback()
    visitor = get_visitor(callback)
    queue_manager = get_queue_manager()
    return Consumer(visitor, queue_manager)
|
|
|
|
|
|
@lru_cache(maxsize=None)
def get_visitor(callback):
    """Return a cached QueueVisitor wired to the shared storage and response strategy."""
    storage = get_storage()
    response_strategy = get_response_strategy()
    return QueueVisitor(storage, callback, response_strategy)
|
|
|
|
|
|
@lru_cache(maxsize=None)
def get_queue_manager():
    """Return the cached Pika queue manager for the configured input/output queues."""
    queues = CONFIG.rabbitmq.queues
    return PikaQueueManager(queues.input, queues.output)
|
|
|
|
|
|
@lru_cache(maxsize=None)
def get_storage():
    """Return the cached storage backend selected by configuration."""
    backend = CONFIG.storage.backend
    return storages.get_storage(backend)
|
|
|
|
|
|
@lru_cache(maxsize=None)
def get_callback(analysis_base_url=None):
    """Return a cached Callback for *analysis_base_url*.

    Falls back to the configured analysis endpoint when the argument is falsy.
    """
    if not analysis_base_url:
        analysis_base_url = CONFIG.rabbitmq.callback.analysis_endpoint
    return Callback(analysis_base_url)
|
|
|
|
|
|
@lru_cache(maxsize=None)
def get_response_strategy(storage=None):
    """Return a cached aggregation-storage response strategy.

    Uses the shared storage from :func:`get_storage` when *storage* is falsy.
    """
    if not storage:
        storage = get_storage()
    return AggregationStorageStrategy(storage)
|
|
|
|
|
|
class Callback:
    """Callable that routes a queue message to the REST analysis pipeline for its operation.

    Instances cache one pipeline per endpoint URL so repeated operations
    reuse the already-constructed pipeline.
    """

    def __init__(self, base_url):
        # Base URL of the analysis service; operation names are appended to it.
        self.base_url = base_url
        # Per-endpoint pipeline cache (endpoint URL -> pipeline).
        self.endpoint2pipeline = {}

    def __make_endpoint(self, operation):
        """Return the full endpoint URL for *operation*."""
        return f"{self.base_url}/{operation}"

    def __get_pipeline(self, endpoint):
        """Return the cached pipeline for *endpoint*, building it on first use."""
        # Single .get() instead of the LBYL "in" check + second lookup.
        pipeline = self.endpoint2pipeline.get(endpoint)
        if pipeline is None:
            pipeline = get_pipeline(endpoint)
            self.endpoint2pipeline[endpoint] = pipeline
        return pipeline

    @staticmethod
    def __run_pipeline(pipeline, body):
        """Run *pipeline* on one queue message *body* and return the response list.

        *body* carries a "data" payload, a "metadata" dict, and arbitrary extra
        queue-message keys. The extra keys are merged into the metadata before
        the pipeline call and stripped from each response's metadata afterwards.

        Raises:
            AnalysisFailure: chained from any exception raised while running
                the pipeline or post-processing its responses.

        TODO: since data and metadata are passed as singletons, there is no buffering and hence no batching happening
        within the pipeline. However, the queue acknowledgment logic needs to be changed in order to facilitate
        passing non-singletons, to only ack a message, once a response is pulled from the output queue of the
        pipeline. Probably the pipeline return value needs to contains the queue message frame (or so), in order for
        the queue manager to tell which message to ack.

        TODO: materializing `analysis_response_stream` into a list is a temporary solution, while the client
        pipeline operates on singletons ([data], [metadata]).
        """

        def queue_message_keys(body):
            # Top-level keys that are queue-frame bookkeeping, not payload.
            return set(body) - {"data", "metadata"}

        def combine_storage_item_metadata_with_queue_message_metadata(body):
            # stdlib equivalent of funcy merge(body["metadata"], omit(body, [...])).
            extras = {key: value for key, value in body.items() if key not in ("data", "metadata")}
            return {**body["metadata"], **extras}

        def remove_queue_message_metadata(result):
            # stdlib equivalent of funcy omit(result["metadata"], queue_message_keys(body)).
            bookkeeping = queue_message_keys(body)
            metadata = {key: value for key, value in result["metadata"].items() if key not in bookkeeping}
            return {**result, "metadata": metadata}

        try:
            data = body["data"]
            metadata = combine_storage_item_metadata_with_queue_message_metadata(body)
            analysis_response_stream = pipeline([data], [metadata])
            return [remove_queue_message_metadata(result) for result in analysis_response_stream]
        except Exception as err:
            # Log with traceback, then surface a domain-specific failure to the caller.
            logger.exception(err)
            raise AnalysisFailure from err

    def __call__(self, body: dict):
        """Dispatch *body* to the endpoint selected by its "operation" key.

        Returns the pipeline's response list, or None when the analysis fails
        (the failure is logged as a warning and deliberately swallowed so the
        consumer keeps processing messages).
        """
        operation = body.get("operation", "submit")
        endpoint = self.__make_endpoint(operation)
        pipeline = self.__get_pipeline(endpoint)

        try:
            # Module logger (not the root logger) with lazy %-formatting.
            logger.debug("Requesting analysis from %s...", endpoint)
            return self.__run_pipeline(pipeline, body)
        except AnalysisFailure:
            logger.warning("Exception caught when calling analysis endpoint %s.", endpoint)
|
|
|
|
|
|
@lru_cache(maxsize=None)
def get_pipeline(endpoint):
    """Build (and cache) the REST client pipeline for *endpoint*."""
    # Components are constructed in the same order the ClientPipeline expects them.
    packer = RestPacker()
    dispatcher = RestDispatcher(endpoint)
    receiver = RestReceiver()
    interpreter = rcompose(RestPickupStreamer(), RestReceiver())
    return ClientPipeline(packer, dispatcher, receiver, interpreter)
|