WIP: add callback factory and update example scripts
This commit is contained in:
parent 6802bf5960
commit b7f860f36b
@@ -7,7 +7,7 @@ from funcy import identity
from prometheus_client import generate_latest, CollectorRegistry, REGISTRY, Summary
from starlette.responses import Response

from pyinfra.config.validation import validate_settings, prometheus_validators
from pyinfra.config.validation import prometheus_validators, validate_settings


def add_prometheus_endpoint(app: FastAPI, registry: CollectorRegistry = REGISTRY) -> FastAPI:
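For orientation, a minimal usage sketch of this endpoint helper, mirroring the updated example consumer script further down (the app setup here is illustrative, not part of the diff):

from fastapi import FastAPI

from pyinfra.monitor.prometheus import add_prometheus_endpoint

app = FastAPI()
app = add_prometheus_endpoint(app)  # uses the default REGISTRY unless another CollectorRegistry is passed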
@@ -1,199 +0,0 @@
from dataclasses import dataclass
from functools import singledispatch, partial
from funcy import project, complement
from itertools import chain
from operator import itemgetter
from typing import Union, Sized, Callable, List

from pyinfra.config import Config
from pyinfra.utils.file_extension_parsing import make_file_extension_parser


@dataclass
class QueueMessagePayload:
    """Default one-to-one payload, where the message contains the absolute file paths for the target and response files,
    that have to be acquired from the storage."""

    target_file_path: str
    response_file_path: str

    target_file_type: Union[str, None]
    target_compression_type: Union[str, None]
    response_file_type: Union[str, None]
    response_compression_type: Union[str, None]

    x_tenant_id: Union[str, None]

    processing_kwargs: dict


@dataclass
class LegacyQueueMessagePayload(QueueMessagePayload):
    """Legacy one-to-one payload, where the message contains the dossier and file ids, and the file extensions that have
    to be used to construct the absolute file paths for the target and response files, that have to be acquired from the
    storage."""

    dossier_id: str
    file_id: str

    target_file_extension: str
    response_file_extension: str


class QueueMessagePayloadParser:
    def __init__(self, payload_matcher2parse_strategy: dict):
        self.payload_matcher2parse_strategy = payload_matcher2parse_strategy

    def __call__(self, payload: dict) -> QueueMessagePayload:
        for payload_matcher, parse_strategy in self.payload_matcher2parse_strategy.items():
            if payload_matcher(payload):
                return parse_strategy(payload)


def get_queue_message_payload_parser(config: Config) -> QueueMessagePayloadParser:
    file_extension_parser = make_file_extension_parser(config.allowed_file_types, config.allowed_compression_types)

    payload_matcher2parse_strategy = get_payload_matcher2parse_strategy(
        file_extension_parser, config.allowed_processing_parameters
    )

    return QueueMessagePayloadParser(payload_matcher2parse_strategy)


def get_payload_matcher2parse_strategy(parse_file_extensions: Callable, allowed_processing_parameters: List[str]):
    return {
        is_legacy_payload: partial(
            parse_legacy_queue_message_payload,
            parse_file_extensions=parse_file_extensions,
            allowed_processing_parameters=allowed_processing_parameters,
        ),
        complement(is_legacy_payload): partial(
            parse_queue_message_payload,
            parse_file_extensions=parse_file_extensions,
            allowed_processing_parameters=allowed_processing_parameters,
        ),
    }


def is_legacy_payload(payload: dict) -> bool:
    return {"dossierId", "fileId", "targetFileExtension", "responseFileExtension"}.issubset(payload.keys())


def parse_queue_message_payload(
    payload: dict,
    parse_file_extensions: Callable,
    allowed_processing_parameters: List[str],
) -> QueueMessagePayload:
    target_file_path, response_file_path = itemgetter("targetFilePath", "responseFilePath")(payload)

    target_file_type, target_compression_type, response_file_type, response_compression_type = chain.from_iterable(
        map(parse_file_extensions, [target_file_path, response_file_path])
    )

    x_tenant_id = payload.get("X-TENANT-ID")

    processing_kwargs = project(payload, allowed_processing_parameters)

    return QueueMessagePayload(
        target_file_path=target_file_path,
        response_file_path=response_file_path,
        target_file_type=target_file_type,
        target_compression_type=target_compression_type,
        response_file_type=response_file_type,
        response_compression_type=response_compression_type,
        x_tenant_id=x_tenant_id,
        processing_kwargs=processing_kwargs,
    )


def parse_legacy_queue_message_payload(
    payload: dict,
    parse_file_extensions: Callable,
    allowed_processing_parameters: List[str],
) -> LegacyQueueMessagePayload:
    dossier_id, file_id, target_file_extension, response_file_extension = itemgetter(
        "dossierId", "fileId", "targetFileExtension", "responseFileExtension"
    )(payload)

    target_file_path = f"{dossier_id}/{file_id}.{target_file_extension}"
    response_file_path = f"{dossier_id}/{file_id}.{response_file_extension}"

    target_file_type, target_compression_type, response_file_type, response_compression_type = chain.from_iterable(
        map(parse_file_extensions, [target_file_extension, response_file_extension])
    )

    x_tenant_id = payload.get("X-TENANT-ID")

    processing_kwargs = project(payload, allowed_processing_parameters)

    return LegacyQueueMessagePayload(
        dossier_id=dossier_id,
        file_id=file_id,
        x_tenant_id=x_tenant_id,
        target_file_extension=target_file_extension,
        response_file_extension=response_file_extension,
        target_file_type=target_file_type,
        target_compression_type=target_compression_type,
        response_file_type=response_file_type,
        response_compression_type=response_compression_type,
        target_file_path=target_file_path,
        response_file_path=response_file_path,
        processing_kwargs=processing_kwargs,
    )


@singledispatch
def format_service_processing_result_for_storage(payload: QueueMessagePayload, result: Sized) -> dict:
    raise NotImplementedError("Unsupported payload type")


@format_service_processing_result_for_storage.register(LegacyQueueMessagePayload)
def _(payload: LegacyQueueMessagePayload, result: Sized) -> dict:
    processing_kwargs = payload.processing_kwargs or {}
    x_tenant_id = {"X-TENANT-ID": payload.x_tenant_id} if payload.x_tenant_id else {}
    return {
        "dossierId": payload.dossier_id,
        "fileId": payload.file_id,
        "targetFileExtension": payload.target_file_extension,
        "responseFileExtension": payload.response_file_extension,
        **x_tenant_id,
        **processing_kwargs,
        "data": result,
    }


@format_service_processing_result_for_storage.register(QueueMessagePayload)
def _(payload: QueueMessagePayload, result: Sized) -> dict:
    processing_kwargs = payload.processing_kwargs or {}
    x_tenant_id = {"X-TENANT-ID": payload.x_tenant_id} if payload.x_tenant_id else {}
    return {
        "targetFilePath": payload.target_file_path,
        "responseFilePath": payload.response_file_path,
        **x_tenant_id,
        **processing_kwargs,
        "data": result,
    }


@singledispatch
def format_to_queue_message_response_body(queue_message_payload: QueueMessagePayload) -> dict:
    raise NotImplementedError("Unsupported payload type")


@format_to_queue_message_response_body.register(LegacyQueueMessagePayload)
def _(payload: LegacyQueueMessagePayload) -> dict:
    processing_kwargs = payload.processing_kwargs or {}
    x_tenant_id = {"X-TENANT-ID": payload.x_tenant_id} if payload.x_tenant_id else {}
    return {"dossierId": payload.dossier_id, "fileId": payload.file_id, **x_tenant_id, **processing_kwargs}


@format_to_queue_message_response_body.register(QueueMessagePayload)
def _(payload: QueueMessagePayload) -> dict:
    processing_kwargs = payload.processing_kwargs or {}
    x_tenant_id = {"X-TENANT-ID": payload.x_tenant_id} if payload.x_tenant_id else {}
    return {
        "targetFilePath": payload.target_file_path,
        "responseFilePath": payload.response_file_path,
        **x_tenant_id,
        **processing_kwargs,
    }
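For context, the removed parser distinguished two message shapes. The payloads below are illustrative only (values are placeholders; the keys come from the code above):

# Illustrative payloads; values are placeholders.
legacy_payload = {
    "dossierId": "dossier",
    "fileId": "file",
    "targetFileExtension": "json.gz",
    "responseFileExtension": "json.gz",
    "X-TENANT-ID": "redaction",  # optional
}
# is_legacy_payload(legacy_payload) is True, so parse_legacy_queue_message_payload is selected.

path_payload = {
    "targetFilePath": "dossier/file.json.gz",
    "responseFilePath": "dossier/file.json.gz",
}
# Any payload without the legacy keys falls through to parse_queue_message_payload via complement(is_legacy_payload).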
@@ -1,97 +0,0 @@
from kn_utils.logging import logger
from dataclasses import asdict
from typing import Callable, List

from pyinfra.config import get_config, Config
from pyinfra.payload_processing.monitor import get_monitor_from_config
from pyinfra.payload_processing.payload import (
    QueueMessagePayloadParser,
    get_queue_message_payload_parser,
    format_service_processing_result_for_storage,
    format_to_queue_message_response_body,
    QueueMessagePayload,
)
from pyinfra.storage.storage import make_downloader, make_uploader
from pyinfra.storage.storage_provider import StorageProvider


class PayloadProcessor:
    def __init__(
        self,
        storage_provider: StorageProvider,
        payload_parser: QueueMessagePayloadParser,
        data_processor: Callable,
    ):
        """Wraps an analysis function specified by a service (e.g. NER service) in pre- and post-processing steps.

        Args:
            storage_provider: Storage manager that connects to the storage, using the tenant id if provided
            payload_parser: Parser that translates the queue message payload to the required QueueMessagePayload object
            data_processor: The analysis function to be called with the downloaded file
                NOTE: The result of the analysis function has to be an instance of `Sized`, e.g. a dict or a list to be
                able to upload it and to be able to monitor the processing time.
        """
        self.parse_payload = payload_parser
        self.provide_storage = storage_provider
        self.process_data = data_processor

    def __call__(self, queue_message_payload: dict) -> dict:
        """Processes a queue message payload.

        The steps executed are:
        1. Download the file specified in the message payload from the storage
        2. Process the file with the analysis function
        3. Upload the result to the storage
        4. Return the payload for a response queue message

        Args:
            queue_message_payload: The payload of a queue message. The payload is expected to be a dict with the
                following keys:
                    targetFilePath, responseFilePath
                OR
                    dossierId, fileId, targetFileExtension, responseFileExtension

        Returns:
            The payload for a response queue message, containing only the request payload.
        """
        return self._process(queue_message_payload)

    def _process(self, queue_message_payload: dict) -> dict:
        payload: QueueMessagePayload = self.parse_payload(queue_message_payload)

        logger.info(f"Processing {payload.__class__.__name__} ...")
        logger.debug(f"Payload contents: {asdict(payload)} ...")

        storage, storage_info = self.provide_storage(payload.x_tenant_id)

        download_file_to_process = make_downloader(
            storage, storage_info.bucket_name, payload.target_file_type, payload.target_compression_type
        )
        upload_processing_result = make_uploader(
            storage, storage_info.bucket_name, payload.response_file_type, payload.response_compression_type
        )

        data = download_file_to_process(payload.target_file_path)
        result: List[dict] = self.process_data(data, **payload.processing_kwargs)
        formatted_result = format_service_processing_result_for_storage(payload, result)

        upload_processing_result(payload.response_file_path, formatted_result)

        return format_to_queue_message_response_body(payload)


def make_payload_processor(data_processor: Callable, config: Config = None) -> PayloadProcessor:
    """Creates a payload processor."""
    config = config or get_config()

    storage_provider = StorageProvider(config)
    monitor = get_monitor_from_config(config)
    payload_parser: QueueMessagePayloadParser = get_queue_message_payload_parser(config)

    data_processor = monitor(data_processor)

    return PayloadProcessor(
        storage_provider,
        payload_parser,
        data_processor,
    )
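For contrast with the new callback factory added below, the removed processor was wired up roughly like this in the pre-change example script (the data processor name is a placeholder):

from pyinfra.config import get_config
from pyinfra.payload_processing.processor import make_payload_processor
from pyinfra.queue.queue_manager import QueueManager

queue_manager = QueueManager(get_config())
queue_manager.start_consuming(make_payload_processor(my_data_processor))  # my_data_processor is a stand-in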
pyinfra/queue/callback.py (new file, 36 lines)
@@ -0,0 +1,36 @@
from typing import Callable, Union

from dynaconf import Dynaconf
from kn_utils.logging import logger

from pyinfra.storage.connection import get_storage
from pyinfra.storage.utils import download_data_as_specified_in_message, upload_data_as_specified_in_message

DataProcessor = Callable[[Union[dict, bytes], dict], dict]


def make_payload_processor(data_processor: DataProcessor, settings: Dynaconf):
    """Default callback for processing queue messages.
    Data will be downloaded from the storage as specified in the message. If a tenant id is specified, the storage
    will be configured to use that tenant id; otherwise the storage is configured as specified in the settings.
    The data is then passed to the data processor, together with the message. The data processor should return a
    JSON-dumpable object. This object is then uploaded to the storage as specified in the message.

    The response message is just the original message.
    Adapt as needed.
    """

    def inner(queue_message_payload: dict) -> dict:
        logger.info("Processing payload...")

        storage = get_storage(settings, queue_message_payload.get("X-TENANT-ID"))

        data = download_data_as_specified_in_message(storage, queue_message_payload)

        result = data_processor(data, queue_message_payload)

        upload_data_as_specified_in_message(storage, queue_message_payload, result)

        return queue_message_payload

    return inner
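A usage sketch of the new factory, following the updated example consumer script further down (my_processor is a placeholder for a service's analysis function):

from pyinfra.config.loader import load_settings
from pyinfra.queue.callback import make_payload_processor
from pyinfra.queue.manager import QueueManager

settings = load_settings()


def my_processor(data, message):
    # Must return a JSON-dumpable object, per the docstring above.
    return {"result": "processed"}


queue_manager = QueueManager(settings)
queue_manager.start_consuming(make_payload_processor(my_processor, settings))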
@@ -13,7 +13,7 @@ from kn_utils.logging import logger
from pika.adapters.blocking_connection import BlockingChannel, BlockingConnection
from retry import retry

from pyinfra.config.validation import validate_settings, queue_manager_validators
from pyinfra.config.validation import queue_manager_validators, validate_settings

pika_logger = logging.getLogger("pika")
pika_logger.setLevel(logging.WARNING)  # disables non-informative pika log clutter
@@ -175,6 +175,7 @@ class QueueManager:
            result,
            properties=pika.BasicProperties(headers=filtered_message_headers),
        )
        # FIXME: publish doesn't work in example script, explore, adapt, overcome
        logger.info(f"Published result to queue {self.output_queue}.")

        channel.basic_ack(delivery_tag=method.delivery_tag)
@@ -8,7 +8,7 @@ from pyinfra.storage.storages.azure import get_azure_storage_from_settings
from pyinfra.storage.storages.s3 import get_s3_storage_from_settings
from pyinfra.storage.storages.storage import Storage
from pyinfra.utils.cipher import decrypt
from pyinfra.config.validation import validate_settings, storage_validators, multi_tenant_storage_validators
from pyinfra.config.validation import storage_validators, multi_tenant_storage_validators, validate_settings


def get_storage(settings: Dynaconf, tenant_id: str = None) -> Storage:
@@ -8,7 +8,7 @@ from minio import Minio
from retry import retry

from pyinfra.storage.storages.storage import Storage
from pyinfra.config.validation import validate_settings, s3_storage_validators
from pyinfra.config.validation import s3_storage_validators, validate_settings
from pyinfra.utils.url_parsing import validate_and_parse_s3_endpoint

@@ -5,7 +5,7 @@ import uvicorn
from dynaconf import Dynaconf
from fastapi import FastAPI

from pyinfra.config.validation import validate_settings, webserver_validators
from pyinfra.config.validation import webserver_validators, validate_settings


def create_webserver_thread_from_settings(app: FastAPI, settings: Dynaconf) -> threading.Thread:
@@ -1,22 +1,20 @@
import gzip
import json
import logging
from operator import itemgetter

import pika
from kn_utils.logging import logger

from pyinfra.config import get_config
from pyinfra.queue.development_queue_manager import DevelopmentQueueManager
from pyinfra.storage.storages.s3 import get_s3_storage_from_config
from pyinfra.config.loader import load_settings
from pyinfra.queue.manager import QueueManager
from pyinfra.storage.storages.s3 import get_s3_storage_from_settings

CONFIG = get_config()
logging.basicConfig()
logger = logging.getLogger()
logger.setLevel(logging.INFO)
settings = load_settings()


def upload_json_and_make_message_body():
    bucket = CONFIG.storage_bucket
    bucket = settings.storage.s3.bucket

    dossier_id, file_id, suffix = "dossier", "file", "json.gz"
    content = {
        "numberOfPages": 7,
@@ -26,10 +24,10 @@ def upload_json_and_make_message_body():
    object_name = f"{dossier_id}/{file_id}.{suffix}"
    data = gzip.compress(json.dumps(content).encode("utf-8"))

    storage = get_s3_storage_from_config(CONFIG)
    if not storage.has_bucket(bucket):
        storage.make_bucket(bucket)
    storage.put_object(bucket, object_name, data)
    storage = get_s3_storage_from_settings(settings)
    if not storage.has_bucket():
        storage.make_bucket()
    storage.put_object(object_name, data)

    message_body = {
        "dossierId": dossier_id,
@@ -41,31 +39,31 @@


def main():
    development_queue_manager = DevelopmentQueueManager(CONFIG)
    development_queue_manager.clear_queues()
    queue_manager = QueueManager(settings)
    queue_manager.purge_queues()

    message = upload_json_and_make_message_body()

    development_queue_manager.publish_request(message, pika.BasicProperties(headers={"X-TENANT-ID": "redaction"}))
    logger.info(f"Put {message} on {CONFIG.request_queue}")
    queue_manager.publish_message_to_input_queue(message)
    logger.info(f"Put {message} on {settings.rabbitmq.input_queue}.")

    storage = get_s3_storage_from_config(CONFIG)
    for method_frame, properties, body in development_queue_manager._channel.consume(
        queue=CONFIG.response_queue, inactivity_timeout=15
    storage = get_s3_storage_from_settings(settings)
    for method_frame, properties, body in queue_manager.channel.consume(
        queue=settings.rabbitmq.output_queue, inactivity_timeout=15
    ):
        if not body:
            break
        response = json.loads(body)
        logger.info(f"Received {response}")
        logger.info(f"Message headers: {properties.headers}")
        development_queue_manager._channel.basic_ack(method_frame.delivery_tag)
        queue_manager.channel.basic_ack(method_frame.delivery_tag)
        dossier_id, file_id = itemgetter("dossierId", "fileId")(response)
        suffix = message["responseFileExtension"]
        print(f"{dossier_id}/{file_id}.{suffix}")
        result = storage.get_object(CONFIG.storage_bucket, f"{dossier_id}/{file_id}.{suffix}")
        result = storage.get_object(f"{dossier_id}/{file_id}.{suffix}")
        result = json.loads(gzip.decompress(result))
        logger.info(f"Contents of result on storage: {result}")
    development_queue_manager.close_channel()
    queue_manager.stop_consuming()


if __name__ == "__main__":
@@ -1,24 +1,42 @@
import logging
import time

from pyinfra.config import get_config
from pyinfra.payload_processing.processor import make_payload_processor
from pyinfra.queue.queue_manager import QueueManager
from fastapi import FastAPI

from pyinfra.config.loader import load_settings
from pyinfra.monitor.prometheus import make_prometheus_processing_time_decorator_from_settings, add_prometheus_endpoint
from pyinfra.queue.callback import make_payload_processor
from pyinfra.queue.manager import QueueManager
from pyinfra.webserver.utils import create_webserver_thread_from_settings

logging.basicConfig()
logger = logging.getLogger()
logger.setLevel(logging.INFO)

settings = load_settings()

def json_processor_mock(data: dict):

@make_prometheus_processing_time_decorator_from_settings(settings)
def json_processor_mock(_data: dict, _message: dict) -> dict:
    time.sleep(5)
    return [{"result1": "result1"}, {"result2": "result2"}]
    return {"result1": "result1"}


def main():
    logger.info("Start consuming...")
    queue_manager = QueueManager(get_config())
    queue_manager.start_consuming(make_payload_processor(json_processor_mock))
    app = FastAPI()
    app = add_prometheus_endpoint(app)

    queue_manager = QueueManager(settings)

    @app.get("/ready")
    @app.get("/health")
    def check_health():
        return queue_manager.is_ready()

    webserver_thread = create_webserver_thread_from_settings(app, settings)
    webserver_thread.start()
    callback = make_payload_processor(json_processor_mock, settings)
    queue_manager.start_consuming(callback)


if __name__ == "__main__":
@@ -1,8 +1,20 @@
import pytest

from pyinfra.config.loader import load_settings
from pyinfra.storage.connection import get_storage_from_settings


@pytest.fixture(scope="session")
def settings():
    return load_settings()


@pytest.fixture(scope="class")
def storage(storage_backend, settings):
    settings.storage.backend = storage_backend

    storage = get_storage_from_settings(settings)
    storage.make_bucket()

    yield storage
    storage.clear_bucket()
@@ -5,23 +5,12 @@ from time import sleep
import pytest
from fastapi import FastAPI

from pyinfra.storage.connection import get_storage_from_settings, get_storage_from_tenant_id
from pyinfra.storage.connection import get_storage_from_tenant_id
from pyinfra.storage.utils import download_data_as_specified_in_message, upload_data_as_specified_in_message
from pyinfra.utils.cipher import encrypt
from pyinfra.webserver.utils import create_webserver_thread


@pytest.fixture(scope="class")
def storage(storage_backend, settings):
    settings.storage.backend = storage_backend

    storage = get_storage_from_settings(settings)
    storage.make_bucket()

    yield storage
    storage.clear_bucket()


@pytest.mark.parametrize("storage_backend", ["azure", "s3"], scope="class")
class TestStorage:
    def test_clearing_bucket_yields_empty_bucket(self, storage):