import logging
from multiprocessing import Process

import requests
from retry import retry

from pyinfra.config import CONFIG, make_art
from pyinfra.exceptions import AnalysisFailure, ConsumerError
from pyinfra.flask import run_probing_webserver, set_up_probing_webserver
from pyinfra.queue.consumer import Consumer
from pyinfra.queue.queue_manager.pika_queue_manager import PikaQueueManager
from pyinfra.storage.storages import get_storage
from pyinfra.visitor import QueueVisitor, StorageStrategy


def make_callback(analysis_endpoint, *, timeout=None):
    """Build the queue-message callback that forwards payloads to the analysis service.

    Args:
        analysis_endpoint: Base URL of the analysis service. Each requested
            operation is POSTed to ``f"{analysis_endpoint}/{operation}"``.
        timeout: Optional ``requests`` timeout (seconds, or a ``(connect, read)``
            tuple) applied to every analysis request. ``None`` (the default)
            waits indefinitely, exactly as ``requests.post`` does without one.

    Returns:
        A callable taking a message mapping and returning the analysis result.
    """

    def callback(message):
        """Run every requested operation against ``message["data"]``.

        ``message["operations"]`` lists the operations to request; when the key
        is absent the single operation ``"/"`` is used. If the only operation is
        ``"/"``, its decoded JSON response is returned directly; otherwise a dict
        mapping each operation to its decoded JSON response is returned.

        Raises:
            AnalysisFailure: if building or sending a request fails, the service
                returns an error status, or the body cannot be decoded as JSON.
        """

        def perform_operation(operation):
            # NOTE(review): with the default operation "/" this yields
            # "<analysis_endpoint>//" (double slash). Confirm the analysis
            # service expects that path before normalising it.
            endpoint = f"{analysis_endpoint}/{operation}"
            try:
                logging.debug("Requesting analysis from %s...", endpoint)
                response = requests.post(endpoint, data=message["data"], timeout=timeout)
                response.raise_for_status()
                decoded = response.json()
                logging.debug("Received response.")
                return decoded
            except Exception as err:
                # Any failure (missing "data", network, HTTP status, JSON) is
                # surfaced to the consumer as a single domain exception.
                logging.warning("Exception caught when calling analysis endpoint %s.", endpoint)
                raise AnalysisFailure() from err

        operations = message.get("operations", ["/"])
        # Operations are requested in order; a repeated operation keeps its last response.
        result = {operation: perform_operation(operation) for operation in operations}

        # A lone root operation is unwrapped so callers receive the bare response.
        if list(result) == ["/"]:
            return result["/"]
        return result

    return callback


def main():
    """Start the probing webserver and consume the input queue until stopped.

    The probing webserver runs in a separate (non-daemonic) process. Queue
    consumption is retried on ``ConsumerError`` according to the ``@retry``
    decorator below (3 tries, 5 s base delay). A ``KeyboardInterrupt`` ends
    consumption and then waits for the webserver process. Any other exception,
    whether raised during setup or during consumption, terminates the webserver
    before being re-raised.
    """
    webserver = Process(target=run_probing_webserver, args=(set_up_probing_webserver(),))
    logging.info(make_art())
    logging.info("Starting webserver...")
    webserver.start()

    try:
        callback = make_callback(CONFIG.rabbitmq.callback.analysis_endpoint)
        storage = get_storage(CONFIG.storage.backend)
        response_strategy = StorageStrategy(storage)
        visitor = QueueVisitor(storage, callback, response_strategy)
        queue_manager = PikaQueueManager(CONFIG.rabbitmq.queues.input, CONFIG.rabbitmq.queues.output)

        # NOTE(review): the same queue_manager instance is reused across retries;
        # confirm it can recover its connection after a consumer failure.
        @retry(ConsumerError, tries=3, delay=5, jitter=(1, 3))
        def consume():
            try:
                consumer = Consumer(visitor, queue_manager)
                consumer.consume_and_publish()
            except Exception as err:
                raise ConsumerError from err

        consume()
    except KeyboardInterrupt:
        pass
    except Exception:
        # Stop the probe before propagating. multiprocessing joins non-daemonic
        # children at interpreter exit, so leaving it running (e.g. after a
        # setup failure) would keep the process from ever exiting.
        webserver.terminate()
        webserver.join()
        raise

    webserver.join()


if __name__ == "__main__":
    logging_level = CONFIG.service.logging_level
    logging.basicConfig(level=logging_level)
    logging.getLogger("pika").setLevel(logging.ERROR)
    logging.getLogger("flask").setLevel(logging.ERROR)
    logging.getLogger("urllib3").setLevel(logging.ERROR)

    main()