From 879b80dc0f9310902e515be9607d11f4e77ce624 Mon Sep 17 00:00:00 2001
From: Matthias Bisping
Date: Thu, 17 Feb 2022 12:47:08 +0100
Subject: [PATCH] refactoring: keep data in memory the entire time instead of
 storing it in files

---
 pyinfra/callback.py                   | 47 +++++++++++++++++++-
 pyinfra/consume.py                    |  3 +-
 pyinfra/storage/azure_blob_storage.py | 14 ++++--
 pyinfra/storage/minio.py              | 14 ++++++
 pyinfra/storage/storage.py            |  4 ++
 requirements.txt                      |  5 ++-
 src/serve.py                          | 60 ++++++++++++++++-----------
 7 files changed, 115 insertions(+), 32 deletions(-)

diff --git a/pyinfra/callback.py b/pyinfra/callback.py
index fcda8bd..4546303 100644
--- a/pyinfra/callback.py
+++ b/pyinfra/callback.py
@@ -5,6 +5,49 @@ from flask import jsonify
 
 from pyinfra.rabbitmq import make_connection, make_channel, declare_queue
 
+
+# def make_request_processor(consumer):
+#     def get_n_previous_attempts(props):
+#         return 0 if props.headers is None else props.headers.get("x-retry-count", 0)
+#
+#     def attempts_remain(n_attempts):
+#         return n_attempts < max_attempts and CONFIG.rabbitmq.retry.enabled
+#
+#     def republish(ch, body, n_current_attempts):
+#         ch.basic_publish(
+#             exchange="",
+#             routing_key=CONFIG.rabbitmq.queues.input,
+#             body=body,
+#             properties=pika.BasicProperties(headers={"x-retry-count": n_current_attempts}),
+#         )
+#
+#     def on_request(ch, method, props, body):
+#
+#         try:
+#             response = consumer(body)
+#             ch.basic_publish(exchange="", routing_key=CONFIG.rabbitmq.queues.output, body=response)
+#             ch.basic_ack(delivery_tag=method.delivery_tag)
+#
+#         except Exception as e:
+#
+#             n_attempts = get_n_previous_attempts(props) + 1
+#
+#             logger.error(f"Message failed to process {n_attempts}/{max_attempts} times. Error: {e}")
+#             if n_attempts == max_attempts:
+#                 logger.exception(f"Adding to dead letter queue. Last exception: {e}")
+#
+#             if attempts_remain(n_attempts):
+#                 republish(ch, body, n_attempts)
+#                 ch.basic_ack(delivery_tag=method.delivery_tag)
+#
+#             else:
+#                 ch.basic_reject(delivery_tag=method.delivery_tag, requeue=False)
+#
+#     max_attempts = CONFIG.rabbitmq.retry.max_attempts
+#
+#     return on_request
+
+
 def make_callback_for_output_queue(json_wrapped_body_processor, output_queue_name):
 
     connection = make_connection()
@@ -12,7 +55,9 @@ def make_callback_for_output_queue(json_wrapped_body_processor, output_queue_nam
     declare_queue(channel, output_queue_name)
 
     def callback(channel, method, _, body):
-        channel.basic_publish(exchange="", routing_key=output_queue_name, body=json_wrapped_body_processor(body))
+
+        result = json_wrapped_body_processor(body)  # TODO: add retry handling; see the commented-out sketch above
+        channel.basic_publish(exchange="", routing_key=output_queue_name, body=result)
         channel.basic_ack(delivery_tag=method.delivery_tag)
 
     return callback
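The callback factory above is meant to be paired with consume() from pyinfra/consume.py (next hunk): consume() owns the input queue and the delivery loop, while the callback publishes each processed result to the output queue. A minimal wiring sketch, not part of the patch; the identity processor is a stand-in, and the CONFIG queue names are taken from the commented retry sketch above:

    # Hypothetical wiring of the callback factory with the consumer loop.
    from pyinfra.callback import make_callback_for_output_queue
    from pyinfra.config import CONFIG
    from pyinfra.consume import consume

    def echo_processor(body: bytes) -> bytes:
        # stand-in for a real JSON-wrapped body processor
        return body

    callback = make_callback_for_output_queue(echo_processor, CONFIG.rabbitmq.queues.output)
    consume(CONFIG.rabbitmq.queues.input, callback)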
diff --git a/pyinfra/consume.py b/pyinfra/consume.py
index a074885..46729fc 100644
--- a/pyinfra/consume.py
+++ b/pyinfra/consume.py
@@ -14,11 +14,10 @@ class ConsumerError(Exception):
 
 # @retry(pika.exceptions.AMQPConnectionError, delay=5, jitter=(1, 3))
 def consume(queue_name: str, on_message_callback: Callable):
-    logging.info("Starting mini-queue...")
     connection = make_connection()
     channel = make_channel(connection)
    declare_queue(channel, queue_name)
-    logging.info("Starting webserver...")
+    logging.info("Started infrastructure.")
 
     while True:
         try:
diff --git a/pyinfra/storage/azure_blob_storage.py b/pyinfra/storage/azure_blob_storage.py
index f3fa40b..54805c6 100644
--- a/pyinfra/storage/azure_blob_storage.py
+++ b/pyinfra/storage/azure_blob_storage.py
@@ -87,12 +87,20 @@ class AzureBlobStorageHandle(StorageHandle):
 
     def _StorageHandle__fget_object(self, container_name: str, object_name: str, target_path):
+        with open(target_path, "wb") as f:
+            data = self.get_object(object_name, container_name)
+            f.write(data)
+
+    def get_object(self, object_name: str, container_name: str = None) -> bytes:
+
+        if container_name is None:
+            container_name = self.default_container_name
+
         container_client = self.__get_container_client(container_name)
         blob_client = container_client.get_blob_client(object_name)
+        blob_data = blob_client.download_blob()
 
-        with open(target_path, "wb") as f:
-            blob_data = blob_client.download_blob()
-            blob_data.readinto(f)
+        return blob_data.readall()
 
     def _StorageHandle__remove_file(self, folder: str, filename: str, container_name: str = None) -> None:
         """Removes a file from the store.
diff --git a/pyinfra/storage/minio.py b/pyinfra/storage/minio.py
index 3a6eed9..681283d 100644
--- a/pyinfra/storage/minio.py
+++ b/pyinfra/storage/minio.py
@@ -88,6 +88,20 @@ class MinioHandle(StorageHandle):
 
     def _StorageHandle__fget_object(self, container_name, object_name, target_path):
         self.client.fget_object(container_name, object_name, target_path)
 
+    def get_object(self, object_name, container_name=None):
+        if container_name is None:
+            container_name = self.default_container_name
+
+        response = None
+
+        try:
+            response = self.client.get_object(container_name, object_name)
+            return response.data
+        finally:
+            if response:
+                response.close()
+                response.release_conn()
+
     def _StorageHandle__remove_file(self, folder: str, filename: str, container_name: str = None) -> None:
         """Removes a file from the store.
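Both handles now implement get_object() to return the raw object payload as bytes rather than streaming it to a local file, which is what lets serve.py (below) gunzip the payload entirely in memory. A minimal usage sketch, not part of the patch; the handle would come from the get_storage() factory used in serve.py, and the object name is an assumption:

    # Hypothetical in-memory round trip with either storage handle.
    import gzip

    compressed = storage.get_object("dossier-1/file-1.pdf.gz")  # bytes, default container
    pdf_bytes = gzip.decompress(compressed)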
diff --git a/pyinfra/storage/storage.py b/pyinfra/storage/storage.py
index e1cae48..95c42b9 100644
--- a/pyinfra/storage/storage.py
+++ b/pyinfra/storage/storage.py
@@ -87,6 +87,10 @@ class StorageHandle:
     def __fget_object(self, *args, **kwargs):
         pass
 
+    @abc.abstractmethod
+    def get_object(self, *args, **kwargs):
+        pass
+
     @staticmethod
     def __storage_path(path, folder: str = None):
         def path_to_filename(path):
diff --git a/requirements.txt b/requirements.txt
index 1fdef37..a0ffed1 100755
--- a/requirements.txt
+++ b/requirements.txt
@@ -4,6 +4,9 @@ envyaml==1.10.211231
 minio==7.1.3
 Flask==2.0.3
 waitress==2.0.0
-tqdm==4.62.3
 azure-core==1.22.1
 azure-storage-blob==12.9.0
+requests==2.27.1
+# dev
+docker-compose==1.29.2
+tqdm==4.62.3
diff --git a/src/serve.py b/src/serve.py
index 3f6fdc0..62b80cb 100644
--- a/src/serve.py
+++ b/src/serve.py
@@ -1,7 +1,5 @@
-import json
+import gzip
 import logging
-import logging
-import tempfile
 
 from multiprocessing import Process
 from operator import itemgetter
@@ -14,42 +12,53 @@ from pyinfra.config import CONFIG
 from pyinfra.consume import consume, ConsumerError
 from pyinfra.storage.azure_blob_storage import AzureBlobStorageHandle
 from pyinfra.storage.minio import MinioHandle
-from pyinfra.utils.file import dossier_id_and_file_id_to_compressed_storage_pdf_object_name, download, unzip
+from pyinfra.utils.file import dossier_id_and_file_id_to_compressed_storage_pdf_object_name
 
 
-def make_file_getter(storage):
-    def get_file(payload, pdf_dir):
-        with tempfile.TemporaryDirectory() as pdf_compressed_dir:
-            dossier_id, file_id = itemgetter("dossierId", "fileId")(payload)
-            object_name = dossier_id_and_file_id_to_compressed_storage_pdf_object_name(dossier_id, file_id)
-            downloaded_file_path = download(storage, object_name, pdf_compressed_dir)
-            unzipped_file_path = unzip(downloaded_file_path, pdf_dir)
-
-        return unzipped_file_path
-
-    return get_file
+def make_storage_data_loader(storage):
+    def get_object_name(payload: dict) -> str:
+        dossier_id, file_id = itemgetter("dossierId", "fileId")(payload)
+        object_name = dossier_id_and_file_id_to_compressed_storage_pdf_object_name(dossier_id, file_id)
+        return object_name
+
+    def download(payload):
+        object_name = get_object_name(payload)
+        logging.debug(f"Downloading {object_name}...")
+        data = storage.get_object(object_name)
+        logging.debug(f"Downloaded {object_name}.")
+        return data
+
+    def decompress(data):
+        return gzip.decompress(data)
+
+    def load_data(payload):
+        return decompress(download(payload))
+
+    return load_data
 
 
-def make_file_analyzer(analysis_endpoint):
-    def analyze_file(file_path):
-        predictions = requests.post(analysis_endpoint, data=open(file_path, "rb"))
-        predictions.raise_for_status()
-        predictions = predictions.json()
-        return predictions
+def make_analyzer(analysis_endpoint):
+    def analyze(data):
+        try:
+            analysis_response = requests.post(analysis_endpoint, data=data)
+            analysis_response.raise_for_status()
+            return analysis_response.json()
+        except Exception:
+            logging.exception("Exception caught when calling analysis endpoint.")
+            raise
 
-    return analyze_file
+    return analyze
 
 
 def make_payload_processor(analysis_endpoint):
-    get_file = make_file_getter(get_storage())
-    analyze_file = make_file_analyzer(analysis_endpoint)
+    load_data = make_storage_data_loader(get_storage())
+    analyze = make_analyzer(analysis_endpoint)
 
     def process(payload: dict):
logging.info(f"Processing {payload}...") - with tempfile.TemporaryDirectory() as pdf_dir: - file_path = get_file(payload, pdf_dir) - predictions = analyze_file(file_path) + data = load_data(payload) + predictions = analyze_file(data) return predictions return process @@ -102,6 +115,7 @@ def main(): ) webserver = Process(target=start_integrity_checks_webserver, args=("debug",)) + logging.info("Starting webserver...") webserver.start() try: