diff --git a/pyinfra/default_objects.py b/pyinfra/default_objects.py index 7c09792..d2dae37 100644 --- a/pyinfra/default_objects.py +++ b/pyinfra/default_objects.py @@ -1,7 +1,8 @@ import logging from functools import lru_cache +from operator import itemgetter -from funcy import rcompose +from funcy import rcompose, pluck from pyinfra.config import CONFIG from pyinfra.exceptions import AnalysisFailure @@ -13,17 +14,17 @@ from pyinfra.server.interpreter.interpreters.rest_callback import RestPickupStre from pyinfra.server.packer.packers.rest import RestPacker from pyinfra.server.receiver.receivers.rest import RestReceiver from pyinfra.storage import storages -from pyinfra.visitor import StorageStrategy, QueueVisitor +from pyinfra.visitor import QueueVisitor, AggregationStorageStrategy @lru_cache(maxsize=None) -def get_consumer(): - return Consumer(get_visitor(), get_queue_manager()) +def get_consumer(callback): + return Consumer(get_visitor(callback), get_queue_manager()) @lru_cache(maxsize=None) -def get_visitor(): - return QueueVisitor(get_storage(), get_callback(), get_response_strategy()) +def get_visitor(callback): + return QueueVisitor(get_storage(), callback, get_response_strategy()) @lru_cache(maxsize=None) @@ -37,27 +38,29 @@ def get_storage(): @lru_cache(maxsize=None) -def get_callback(): - return make_callback(CONFIG.rabbitmq.callback.analysis_endpoint) +def get_callback(analysis_endpoint=None): + analysis_endpoint = analysis_endpoint or CONFIG.rabbitmq.callback.analysis_endpoint + return make_callback(analysis_endpoint) @lru_cache(maxsize=None) def get_response_strategy(): - return StorageStrategy(get_storage()) + return AggregationStorageStrategy(get_storage()) @lru_cache(maxsize=None) def make_callback(analysis_endpoint): - def callback(data, metadata: dict): + def callback(body: dict): try: + data, metadata = itemgetter("data", "metadata")(body) logging.debug(f"Requesting analysis from {endpoint}...") - analysis_response_stream = pipeline([data], [metadata]) + analysis_response_stream = pluck("data", pipeline([data], [metadata])) return analysis_response_stream except Exception as err: logging.warning(f"Exception caught when calling analysis endpoint {endpoint}.") raise AnalysisFailure() from err - endpoint = f"{analysis_endpoint}/submit" + endpoint = f"{analysis_endpoint}" pipeline = get_pipeline(endpoint) return callback @@ -85,11 +88,11 @@ def get_pipeline(endpoint): # logging.warning(f"Exception caught when calling analysis endpoint {endpoint}.") # raise AnalysisFailure() from err # -# operations = message.get("operations", ["/"]) +# operations = message.get("operations", [""]) # results = map(perform_operation, operations) # result = dict(zip(operations, results)) # if list(result.keys()) == ["/"]: # result = list(result.values())[0] # return result # -# return callback \ No newline at end of file +# return callback diff --git a/pyinfra/queue/queue_manager/pika_queue_manager.py b/pyinfra/queue/queue_manager/pika_queue_manager.py index feecea6..af65545 100644 --- a/pyinfra/queue/queue_manager/pika_queue_manager.py +++ b/pyinfra/queue/queue_manager/pika_queue_manager.py @@ -6,6 +6,7 @@ import pika from pyinfra.config import CONFIG from pyinfra.exceptions import ProcessingFailure from pyinfra.queue.queue_manager.queue_manager import QueueHandle, QueueManager +from pyinfra.visitor import QueueVisitor logger = logging.getLogger("pika") logger.setLevel(logging.WARNING) @@ -98,7 +99,7 @@ class PikaQueueManager(QueueManager): logger.exception(f"Adding to dead letter queue: {body}") self.channel.basic_reject(delivery_tag=frame.delivery_tag, requeue=False) - def publish_response(self, message, callback, max_attempts=3): + def publish_response(self, message, visitor: QueueVisitor, max_attempts=3): logger.debug(f"Processing {message}.") @@ -107,7 +108,9 @@ class PikaQueueManager(QueueManager): n_attempts = get_n_previous_attempts(properties) + 1 try: - response = json.dumps(callback(json.loads(body))) + body = json.loads(body) + response = visitor(body) + response = json.dumps(response) self.channel.basic_publish("", self._output_queue, response.encode()) self.channel.basic_ack(frame.delivery_tag) except ProcessingFailure: @@ -126,14 +129,14 @@ class PikaQueueManager(QueueManager): logger.debug("Consuming") return self.channel.consume(self._input_queue, inactivity_timeout=inactivity_timeout) - def consume_and_publish(self, visitor): + def consume_and_publish(self, visitor: QueueVisitor): logger.info(f"Consuming input queue.") for message in self.consume(): self.publish_response(message, visitor) - def basic_consume_and_publish(self, visitor): + def basic_consume_and_publish(self, visitor: QueueVisitor): logger.info(f"Basic consuming input queue.") diff --git a/pyinfra/visitor.py b/pyinfra/visitor.py index e49a4cd..5c74418 100644 --- a/pyinfra/visitor.py +++ b/pyinfra/visitor.py @@ -2,11 +2,15 @@ import abc import gzip import json import logging +from collections import deque from operator import itemgetter -from typing import Callable +from typing import Callable, Generator + +from funcy import omit from pyinfra.config import CONFIG, parse_disjunction_string from pyinfra.exceptions import DataLoadingFailure +from pyinfra.server.packing import string_to_bytes from pyinfra.storage.storage import Storage @@ -57,6 +61,67 @@ class ForwardingStrategy(ResponseStrategy): return body +class DispatchCallback(abc.ABC): + @abc.abstractmethod + def __call__(self, payload): + pass + + +class IdentifierDispatchCallback(DispatchCallback): + def __init__(self): + self.identifier = None + + def has_new_identifier(self, metadata): + + identifier = ":".join(itemgetter("fileId", "dossierId")(metadata)) + + if not self.identifier: + self.identifier = identifier + + return identifier != self.identifier + + def __call__(self, payload): + return self.has_new_identifier(payload) + + +class AggregationStorageStrategy(ResponseStrategy): + def __init__(self, storage, merger: Callable = None, dispatch_callback: DispatchCallback = None): + self.storage = storage + self.merger = merger or list + self.dispatch_callback = dispatch_callback or IdentifierDispatchCallback() + self.buffer = deque() + + def put_object(self, data, metadata): + object_descriptor = get_response_object_descriptor(metadata) + self.storage.put_object(**object_descriptor, data=gzip.compress(string_to_bytes(data))) + + def merge_queue_items(self): + merged_buffer_content = self.merger(self.buffer) + self.buffer.clear() + return merged_buffer_content + + def upload_queue_items(self, metadata): + data = self.merge_queue_items() + self.put_object(data, metadata) + + def upload_or_aggregate(self, data, metadata): + + if isinstance(data, str): + self.put_object(data, metadata) + + else: + self.buffer.append(data) + if self.dispatch_callback(metadata): + self.upload_queue_items(metadata) + + def handle_response(self, payload, final=False): + metadata = omit(payload, ["data"]) + data = payload["data"] + for item in data: + self.upload_or_aggregate(item, metadata) + return metadata + + class QueueVisitor: def __init__(self, storage: Storage, callback: Callable, response_strategy): self.storage = storage @@ -79,11 +144,12 @@ class QueueVisitor: raise DataLoadingFailure() from err def process_data(self, data, body): - return self.callback({**body, "data": data}) + return self.callback({"data": data, "metadata": body}) def load_and_process(self, body): - data = self.process_data(self.load_data(body), body) - result_body = {**body, "data": data} + data_from_storage = self.load_data(body) + result = self.process_data(data_from_storage, body) + result_body = {"data": result, **body} return result_body def __call__(self, body):