"""Service entry point: run a probing webserver and consume the RabbitMQ input queue."""

import logging
from multiprocessing import Process

import pika

from pyinfra.callback import (
    make_retry_callback_for_output_queue,
    make_retry_callback,
    make_callback_for_output_queue,
)
from pyinfra.config import CONFIG
from pyinfra.consume import consume, ConsumerError
from pyinfra.core import make_payload_processor, make_storage_data_loader, make_analyzer
from pyinfra.exceptions import UnknownStorageBackend
from pyinfra.flask import run_probing_webserver, set_up_probing_webserver
from pyinfra.storage.storages import get_azure_storage, get_s3_storage


def get_storage():
    """Return the storage client selected by ``CONFIG.storage.backend``.

    Supported values are ``"s3"`` (``get_s3_storage()``) and ``"azure"``
    (``get_azure_storage()``).

    Raises:
        UnknownStorageBackend: if the configured backend is neither of the above.
    """
    storage_backend = CONFIG.storage.backend
    if storage_backend == "s3":
        storage = get_s3_storage()
    elif storage_backend == "azure":
        storage = get_azure_storage()
    else:
        raise UnknownStorageBackend(f"Unknown storage backend '{storage_backend}'.")
    return storage


def republish(channel, body, n_current_attempts):
    """Publish ``body`` back onto the configured input queue for another attempt.

    The message goes to the default exchange (``""``) with the input queue name
    as routing key. ``n_current_attempts`` is carried in the ``x-retry-count``
    header.

    Args:
        channel: channel exposing ``basic_publish`` (presumably a pika channel —
            supplied by the retry machinery in ``pyinfra.callback``).
        body: the raw message body to re-enqueue.
        n_current_attempts: attempt count to record in the ``x-retry-count`` header.
    """
    channel.basic_publish(
        exchange="",
        routing_key=CONFIG.rabbitmq.queues.input,
        body=body,
        properties=pika.BasicProperties(headers={"x-retry-count": n_current_attempts}),
    )


def make_callback():
    """Build the message callback used to consume the input queue.

    The callback is assembled as follows:

    - Storage loading (``get_storage()`` with ``CONFIG.storage.bucket``) and
      the analyzer for ``CONFIG.rabbitmq.callback.analysis_endpoint`` are
      combined into a payload processor.
    - That processor is wrapped in a callback targeting
      ``CONFIG.rabbitmq.queues.output``.
    - When ``CONFIG.rabbitmq.callback.retry.enabled`` is set, a retry callback
      built from :func:`republish` and ``retry.max_attempts`` is attached. The
      exact retry policy lives in ``pyinfra.callback``.

    Returns:
        The callback to pass to ``consume``.
    """
    load_data = make_storage_data_loader(get_storage(), CONFIG.storage.bucket)
    analyze_file = make_analyzer(CONFIG.rabbitmq.callback.analysis_endpoint)

    json_wrapped_body_processor = make_payload_processor(load_data, analyze_file)

    if CONFIG.rabbitmq.callback.retry.enabled:
        retry_callback = make_retry_callback(
            republish, max_attempts=CONFIG.rabbitmq.callback.retry.max_attempts
        )
        callback = make_retry_callback_for_output_queue(
            json_wrapped_body_processor=json_wrapped_body_processor,
            output_queue_name=CONFIG.rabbitmq.queues.output,
            retry_callback=retry_callback,
        )
    else:
        callback = make_callback_for_output_queue(
            json_wrapped_body_processor=json_wrapped_body_processor,
            output_queue_name=CONFIG.rabbitmq.queues.output,
        )

    return callback


def main():
    """Start the probing webserver in a child process, then consume the input queue.

    If consumption is interrupted (``KeyboardInterrupt``), the webserver child
    is terminated and joined. If consumption fails for any other reason
    (including errors raised while building the callback), the child is
    terminated and the exception is re-raised. If ``consume`` returns normally,
    this waits for the webserver process to exit.
    """
    # TODO: implement meaningful checks
    webserver = Process(target=run_probing_webserver, args=(set_up_probing_webserver(),))
    logging.info("Starting webserver...")
    webserver.start()

    try:
        consume(CONFIG.rabbitmq.queues.input, make_callback())
    except KeyboardInterrupt:
        # Terminate explicitly: the interrupt may have been delivered only to this
        # process, in which case join() below would otherwise block forever.
        webserver.terminate()
    except BaseException:
        # Covers ConsumerError as well as anything else (e.g. UnknownStorageBackend
        # from make_callback(), SystemExit). The webserver is a non-daemon child,
        # which multiprocessing joins at interpreter exit, so leaving it running
        # would make the service hang instead of exiting with the error.
        webserver.terminate()
        raise

    webserver.join()


if __name__ == "__main__":
    logging_level = CONFIG.service.logging_level
    logging.basicConfig(level=logging_level)
    # Silence chatty third-party loggers below ERROR.
    logging.getLogger("pika").setLevel(logging.ERROR)
    logging.getLogger("flask").setLevel(logging.ERROR)
    logging.getLogger("urllib3").setLevel(logging.ERROR)

    main()