refactor: download and upload file logic, module structure; remove files made redundant so far

Julius Unverfehrt 2024-01-18 15:45:28 +01:00
parent ec5ad09fa8
commit 6802bf5960
31 changed files with 221 additions and 702 deletions

View File

@@ -1,131 +0,0 @@
import os
from os import environ
from pathlib import Path
from typing import Union
from dynaconf import Dynaconf
from pyinfra.utils.url_parsing import validate_and_parse_s3_endpoint
def read_from_environment(environment_variable_name, default_value):
return environ.get(environment_variable_name, default_value)
def normalize_bool(value: Union[str, bool]):
return value if isinstance(value, bool) else value in ["True", "true"]
class Config:
def __init__(self):
# Logging level for service logger
self.logging_level_root = read_from_environment("LOGGING_LEVEL_ROOT", "DEBUG")
# Enables Prometheus monitoring
self.monitoring_enabled = normalize_bool(read_from_environment("MONITORING_ENABLED", True))
# Prometheus metric prefix, per convention '{product_name}_{service_name}_{parameter}'
# In the current implementation, the results of a service define the parameter that is monitored,
# i.e. analysis result per image means processing time per image is monitored.
# TODO: add validator since some characters like '-' are not allowed by python prometheus
self.prometheus_metric_prefix = read_from_environment(
"PROMETHEUS_METRIC_PREFIX", "redactmanager_research_service_parameter"
)
# Prometheus webserver address and port
self.prometheus_host = "0.0.0.0"
self.prometheus_port = 8080
# RabbitMQ host address
self.rabbitmq_host = read_from_environment("RABBITMQ_HOST", "localhost")
# RabbitMQ host port
self.rabbitmq_port = read_from_environment("RABBITMQ_PORT", "5672")
# RabbitMQ username
self.rabbitmq_username = read_from_environment("RABBITMQ_USERNAME", "user")
# RabbitMQ password
self.rabbitmq_password = read_from_environment("RABBITMQ_PASSWORD", "bitnami")
# Controls AMQP heartbeat timeout in seconds
self.rabbitmq_heartbeat = int(read_from_environment("RABBITMQ_HEARTBEAT", 1))
# Controls AMQP connection sleep timer in seconds
# important for heartbeat to come through while main function runs on other thread
self.rabbitmq_connection_sleep = int(read_from_environment("RABBITMQ_CONNECTION_SLEEP", 5))
# Queue name for requests to the service
self.request_queue = read_from_environment("REQUEST_QUEUE", "request_queue")
# Queue name for responses by service
self.response_queue = read_from_environment("RESPONSE_QUEUE", "response_queue")
# Queue name for failed messages
self.dead_letter_queue = read_from_environment("DEAD_LETTER_QUEUE", "dead_letter_queue")
# The type of storage to use {s3, azure}
self.storage_backend = read_from_environment("STORAGE_BACKEND", "s3")
# The bucket / container to pull files specified in queue requests from
if self.storage_backend == "s3":
self.storage_bucket = read_from_environment("STORAGE_BUCKET_NAME", "redaction")
else:
self.storage_bucket = read_from_environment("STORAGE_AZURECONTAINERNAME", "redaction")
# S3 connection security flag and endpoint
storage_address = read_from_environment("STORAGE_ENDPOINT", "http://127.0.0.1:9000")
self.storage_secure_connection, self.storage_endpoint = validate_and_parse_s3_endpoint(storage_address)
# User for s3 storage
self.storage_key = read_from_environment("STORAGE_KEY", "root")
# Password for s3 storage
self.storage_secret = read_from_environment("STORAGE_SECRET", "password")
# Region for s3 storage
self.storage_region = read_from_environment("STORAGE_REGION", "eu-central-1")
# Connection string for Azure storage
self.storage_azureconnectionstring = read_from_environment(
"STORAGE_AZURECONNECTIONSTRING",
"DefaultEndpointsProtocol=...",
)
# Allowed file types for downloaded and uploaded storage objects that get processed by the service
self.allowed_file_types = ["json", "pdf"]
self.allowed_compression_types = ["gz"]
self.allowed_processing_parameters = ["operation"]
# config for x-tenant-endpoint to receive storage connection information per tenant
self.tenant_decryption_public_key = read_from_environment("TENANT_PUBLIC_KEY", "redaction")
self.tenant_endpoint = read_from_environment(
"TENANT_ENDPOINT", "http://tenant-user-management:8081/internal-api/tenants"
)
# Value to see if we should write a consumer token to a file
self.write_consumer_token = read_from_environment("WRITE_CONSUMER_TOKEN", "False")
def get_config() -> Config:
return Config()
def load_settings():
# TODO: Make dynamic, so that the settings.toml file can be loaded from any location
# TODO: add validation
root_path = Path(__file__).resolve().parents[0] # this is pyinfra/
repo_root_path = root_path.parents[0] # this is the root of the repo
os.environ["ROOT_PATH"] = str(root_path)
os.environ["REPO_ROOT_PATH"] = str(repo_root_path)
settings = Dynaconf(
load_dotenv=True,
envvar_prefix=False,
settings_files=[
repo_root_path / "config" / "settings.toml",
],
)
return settings
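For illustration, the override pattern used throughout the removed Config class boils down to the following sketch (the environment variable names are taken from the class above; the override values and local names are hypothetical):

import os

# Simulated deployment overrides (hypothetical values).
os.environ["RABBITMQ_HOST"] = "rabbitmq.internal"
os.environ["MONITORING_ENABLED"] = "false"

# read_from_environment: the environment value wins, otherwise the hard-coded default is used.
rabbitmq_host = os.environ.get("RABBITMQ_HOST", "localhost")  # -> "rabbitmq.internal"

# normalize_bool: env vars arrive as strings, so only "True"/"true" map to True.
raw = os.environ.get("MONITORING_ENABLED", True)
monitoring_enabled = raw if isinstance(raw, bool) else raw in ["True", "true"]  # -> False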

View File

pyinfra/config/loader.py (new file, 23 lines added)
View File

@@ -0,0 +1,23 @@
import os
from pathlib import Path
from dynaconf import Dynaconf
def load_settings():
# TODO: Make dynamic, so that the settings.toml file can be loaded from any location
# TODO: add validation
root_path = Path(__file__).resolve().parents[1] # this is pyinfra/
repo_root_path = root_path.parents[0] # this is the root of the repo
os.environ["ROOT_PATH"] = str(root_path)
os.environ["REPO_ROOT_PATH"] = str(repo_root_path)
settings = Dynaconf(
load_dotenv=True,
envvar_prefix=False,
settings_files=[
repo_root_path / "config" / "settings.toml",
],
)
return settings
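A minimal usage sketch of the new loader (the storage.s3.endpoint key matches the one the integration tests below read from the settings; the printed value depends on the local settings.toml):

from pyinfra.config.loader import load_settings

settings = load_settings()  # reads config/settings.toml, .env files and plain environment variables
# Dynaconf exposes the TOML tables as attribute paths, e.g. the S3 endpoint used in the tests:
print(settings.storage.s3.endpoint)
# With envvar_prefix=False, values can also be overridden via environment variables
# instead of editing settings.toml.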

View File

@@ -1,5 +0,0 @@
class ProcessingFailure(RuntimeError):
pass
class UnknownStorageBackend(Exception):
pass

View File

@@ -1,5 +1,5 @@
from time import time
from typing import Sized, Callable, TypeVar
from typing import Callable, TypeVar
from dynaconf import Dynaconf
from fastapi import FastAPI
@@ -7,7 +7,7 @@ from funcy import identity
from prometheus_client import generate_latest, CollectorRegistry, REGISTRY, Summary
from starlette.responses import Response
from pyinfra.utils.config_validation import validate_settings, prometheus_validators
from pyinfra.config.validation import validate_settings, prometheus_validators
def add_prometheus_endpoint(app: FastAPI, registry: CollectorRegistry = REGISTRY) -> FastAPI:

View File

@@ -13,7 +13,7 @@ from kn_utils.logging import logger
from pika.adapters.blocking_connection import BlockingChannel, BlockingConnection
from retry import retry
from pyinfra.utils.config_validation import validate_settings, queue_manager_validators
from pyinfra.config.validation import validate_settings, queue_manager_validators
pika_logger = logging.getLogger("pika")
pika_logger.setLevel(logging.WARNING) # disables non-informative pika log clutter

View File

@@ -1,18 +1,14 @@
from functools import lru_cache, partial
from typing import Callable
from functools import lru_cache
import requests
from dynaconf import Dynaconf
from funcy import compose
from kn_utils.logging import logger
from pyinfra.storage.storages.azure import get_azure_storage_from_settings
from pyinfra.storage.storages.interface import Storage
from pyinfra.storage.storages.s3 import get_s3_storage_from_settings
from pyinfra.storage.storages.storage import Storage
from pyinfra.utils.cipher import decrypt
from pyinfra.utils.compressing import get_decompressor, get_compressor
from pyinfra.utils.config_validation import validate_settings, storage_validators, multi_tenant_storage_validators
from pyinfra.utils.encoding import get_decoder, get_encoder
from pyinfra.config.validation import validate_settings, storage_validators, multi_tenant_storage_validators
def get_storage(settings: Dynaconf, tenant_id: str = None) -> Storage:
@@ -55,7 +51,7 @@ def get_storage_from_tenant_id(tenant_id: str, settings: Dynaconf) -> Storage:
if maybe_azure:
connection_string = decrypt(public_key, maybe_azure["connectionString"])
backend = "azure"
storage_settings = {
storage_info = {
"storage": {
"azure": {
"connection_string": connection_string,
@@ -66,7 +62,7 @@ def get_storage_from_tenant_id(tenant_id: str, settings: Dynaconf) -> Storage:
elif maybe_s3:
secret = decrypt(public_key, maybe_s3["secret"])
backend = "s3"
storage_settings = {
storage_info = {
"storage": {
"s3": {
"endpoint": maybe_s3["endpoint"],
@@ -81,7 +77,7 @@ def get_storage_from_tenant_id(tenant_id: str, settings: Dynaconf) -> Storage:
raise Exception(f"Unknown storage backend in {response}.")
storage_settings = Dynaconf()
storage_settings.update(settings)
storage_settings.update(storage_info)
storage = storage_dispatcher[backend](storage_settings)
@@ -94,31 +90,3 @@ storage_dispatcher = {
"azure": get_azure_storage_from_settings,
"s3": get_s3_storage_from_settings,
}
@lru_cache(maxsize=10)
def make_downloader(storage: Storage, bucket: str, file_type: str, compression_type: str) -> Callable:
verify = partial(verify_existence, storage, bucket)
download = partial(storage.get_object, bucket)
decompress = get_decompressor(compression_type)
decode = get_decoder(file_type)
return compose(decode, decompress, download, verify)
@lru_cache(maxsize=10)
def make_uploader(storage: Storage, bucket: str, file_type: str, compression_type: str) -> Callable:
upload = partial(storage.put_object, bucket)
compress = get_compressor(compression_type)
encode = get_encoder(file_type)
def inner(file_name, file_bytes):
upload(file_name, compose(compress, encode)(file_bytes))
return inner
def verify_existence(storage: Storage, bucket: str, file_name: str) -> str:
if not storage.exists(file_name):
raise FileNotFoundError(f"{file_name=} name not found on storage in {storage.bucket=}.")
return file_name
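These composed helpers are superseded by the explicit download/upload functions added in pyinfra/storage/utils.py below; roughly, a downloader produced by make_downloader(storage, bucket, "json", "gz") behaved like this behaviour-equivalent sketch (function and argument names are illustrative):

import gzip
import json

def sketch_downloader(storage, bucket, file_name):
    # verify_existence: fail early if the object is missing
    if not storage.exists(file_name):
        raise FileNotFoundError(f"{file_name=} not found on storage.")
    data = storage.get_object(bucket, file_name)  # download (the old interface took the bucket explicitly)
    data = gzip.decompress(data)                  # get_decompressor("gz")
    return json.loads(data.decode("utf-8"))       # get_decoder("json")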

View File

@@ -7,8 +7,8 @@ from dynaconf import Dynaconf
from kn_utils.logging import logger
from retry import retry
from pyinfra.storage.storages.interface import Storage
from pyinfra.utils.config_validation import azure_storage_validators, validate_settings
from pyinfra.storage.storages.storage import Storage
from pyinfra.config.validation import azure_storage_validators, validate_settings
logging.getLogger("azure").setLevel(logging.WARNING)
logging.getLogger("urllib3").setLevel(logging.WARNING)

View File

@@ -1,39 +0,0 @@
from pyinfra.storage.storages.interface import Storage
class StorageMock(Storage):
def __init__(self, data: bytes = None, file_name: str = None, bucket: str = None):
self.data = data
self.file_name = file_name
self._bucket = bucket
@property
def bucket(self):
return self._bucket
def make_bucket(self):
pass
def has_bucket(self):
return True
def put_object(self, object_name, data):
self.file_name = object_name
self.data = data
def exists(self, object_name):
return self.file_name == object_name
def get_object(self, object_name):
return self.data
def get_all_objects(self):
raise NotImplementedError
def clear_bucket(self):
self._bucket = None
self.file_name = None
self.data = None
def get_all_object_names(self):
raise NotImplementedError

View File

@@ -7,8 +7,8 @@ from kn_utils.logging import logger
from minio import Minio
from retry import retry
from pyinfra.storage.storages.interface import Storage
from pyinfra.utils.config_validation import validate_settings, s3_storage_validators
from pyinfra.storage.storages.storage import Storage
from pyinfra.config.validation import validate_settings, s3_storage_validators
from pyinfra.utils.url_parsing import validate_and_parse_s3_endpoint

pyinfra/storage/utils.py (new file, 106 lines added)
View File

@@ -0,0 +1,106 @@
import gzip
import json
from typing import Union
from kn_utils.logging import logger
from pydantic import BaseModel, ValidationError
from pyinfra.storage.storages.storage import Storage
class DossierIdFileIdDownloadPayload(BaseModel):
dossierId: str
fileId: str
targetFileExtension: str
@property
def targetFilePath(self):
return f"{self.dossierId}/{self.fileId}.{self.targetFileExtension}"
class DossierIdFileIdUploadPayload(BaseModel):
dossierId: str
fileId: str
responseFileExtension: str
@property
def responseFilePath(self):
return f"{self.dossierId}/{self.fileId}.{self.responseFileExtension}"
class TargetResponseFilePathDownloadPayload(BaseModel):
targetFilePath: str
class TargetResponseFilePathUploadPayload(BaseModel):
responseFilePath: str
def download_data_as_specified_in_message(storage: Storage, raw_payload: dict) -> Union[dict, bytes]:
"""Convenience function to download a file specified in a message payload.
Supports both legacy and new payload formats.
If the content is compressed with gzip (.gz), it will be decompressed (-> bytes).
If the content is a json file, it will be decoded (-> dict).
In all other cases, the content will be returned as is (-> bytes).
If no file is specified in the payload or the file does not exist in storage, an exception will be raised.
This function can be extended in the future as needed (e.g. to handle more file types). Since further requirements
are not specified at this point and it is unclear what they would entail, the code is kept simple for now: this keeps
it readable and maintainable and avoids refactoring generic solutions that turn out not to be as generic as they seemed.
"""
try:
if "dossierId" in raw_payload:
payload = DossierIdFileIdDownloadPayload(**raw_payload)
else:
payload = TargetResponseFilePathDownloadPayload(**raw_payload)
except ValidationError:
raise ValueError("No download file path found in payload, nothing to download.")
if not storage.exists(payload.targetFilePath):
raise FileNotFoundError(f"File '{payload.targetFilePath}' does not exist in storage.")
data = storage.get_object(payload.targetFilePath)
data = gzip.decompress(data) if ".gz" in payload.targetFilePath else data
data = json.loads(data.decode("utf-8")) if ".json" in payload.targetFilePath else data
return data
def upload_data_as_specified_in_message(storage: Storage, raw_payload: dict, data):
"""Convenience function to upload a file specified in a message payload. For now, only json-dump-able data is
supported. The stored json consists of the raw_payload extended with a 'data' key that contains the data
to be uploaded.
If the content is not a json-dump-able object, an exception will be raised.
If the result file identifier specifies compression with gzip (.gz), it will be compressed before upload.
This function can be extended in the future as needed (e.g. if we need to upload images). Since further requirements
are not specified at this point and it is unclear what they would entail, the code is kept simple for now: this keeps
it readable and maintainable and avoids refactoring generic solutions that turn out not to be as generic as they seemed.
"""
try:
if "dossierId" in raw_payload:
payload = DossierIdFileIdUploadPayload(**raw_payload)
else:
payload = TargetResponseFilePathUploadPayload(**raw_payload)
except ValidationError:
raise ValueError("No upload file path found in payload, nothing to upload.")
if ".json" not in payload.responseFilePath:
raise ValueError("Only json-dump-able data can be uploaded.")
data = {**raw_payload, "data": data}
data = json.dumps(data).encode("utf-8")
data = gzip.compress(data) if ".gz" in payload.responseFilePath else data
storage.put_object(payload.responseFilePath, data)
logger.info(f"Uploaded {payload.responseFilePath} to storage.")
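Taken together, a consuming service would use the two helpers roughly as in this sketch (the file paths and the result dict are placeholders; the modules and functions are the ones introduced or kept by this commit):

from pyinfra.config.loader import load_settings
from pyinfra.storage.connection import get_storage
from pyinfra.storage.utils import download_data_as_specified_in_message, upload_data_as_specified_in_message

settings = load_settings()
storage = get_storage(settings)  # or get_storage(settings, tenant_id=...) for multi-tenant setups

payload = {
    "targetFilePath": "dossier/file.target.json.gz",       # placeholder path
    "responseFilePath": "dossier/file.response.json.gz",   # placeholder path
}
document = download_data_as_specified_in_message(storage, payload)  # decompressed and decoded to a dict
result = {"entities": []}                                            # placeholder processing result
upload_data_as_specified_in_message(storage, payload, result)        # stores {**payload, "data": result}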

View File

@@ -1,22 +0,0 @@
import gzip
from typing import Union, Callable
from funcy import identity
def get_decompressor(compression_type: Union[str, None]) -> Callable:
if not compression_type:
return identity
elif "gz" in compression_type:
return gzip.decompress
else:
raise ValueError(f"{compression_type=} is not supported.")
def get_compressor(compression_type: str) -> Callable:
if not compression_type:
return identity
elif "gz" in compression_type:
return gzip.compress
else:
raise ValueError(f"{compression_type=} is not supported.")
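The removed dispatchers were symmetric, so a gz roundtrip through them looked like this:

compress = get_compressor("gz")        # -> gzip.compress
decompress = get_decompressor("gz")    # -> gzip.decompress
assert decompress(compress(b"payload")) == b"payload"
assert get_decompressor(None)(b"payload") == b"payload"   # no compression type -> funcy.identity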

View File

@@ -1,5 +0,0 @@
from funcy import project
def safe_project(mapping, keys) -> dict:
return project(mapping, keys) if mapping else {}

View File

@@ -1,28 +0,0 @@
import json
from typing import Callable
from funcy import identity
def decode_json(data: bytes) -> dict:
return json.loads(data.decode("utf-8"))
def encode_json(data: dict) -> bytes:
return json.dumps(data).encode("utf-8")
def get_decoder(file_type: str) -> Callable:
if "json" in file_type:
return decode_json
elif "pdf" in file_type:
return identity
else:
raise ValueError(f"{file_type=} is not supported.")
def get_encoder(file_type: str) -> Callable:
if "json" in file_type:
return encode_json
else:
raise ValueError(f"{file_type=} is not supported.")

View File

@@ -1,41 +0,0 @@
from collections import defaultdict
from typing import Callable
from funcy import merge
def make_file_extension_parser(file_types, compression_types):
ext2_type2ext = make_ext2_type2ext(file_types, compression_types)
ext_to_type2ext = make_ext_to_type2ext(ext2_type2ext)
def inner(path):
file_extensions = parse_file_extensions(path, ext_to_type2ext)
return file_extensions.get("file_type"), file_extensions.get("compression_type")
return inner
def make_ext2_type2ext(file_type_extensions, compression_type_extensions):
def make_ext_to_ext2type(ext_type):
return lambda ext: {ext_type: ext}
ext_to_file_type_mapper = make_ext_to_ext2type("file_type")
ext_to_compression_type_mapper = make_ext_to_ext2type("compression_type")
return defaultdict(
lambda: lambda _: {},
{
**{e: ext_to_file_type_mapper for e in file_type_extensions},
**{e: ext_to_compression_type_mapper for e in compression_type_extensions},
},
)
def make_ext_to_type2ext(ext2_type2ext):
def ext_to_type2ext(ext):
return ext2_type2ext[ext](ext)
return ext_to_type2ext
def parse_file_extensions(path, ext_to_type2ext: Callable):
return merge(*map(ext_to_type2ext, path.split(".")))
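With the allowed types from the removed config (file types ["json", "pdf"], compression types ["gz"]), the removed parser behaved as follows; similar cases are exercised by the deleted test further down:

parse = make_file_extension_parser(["json", "pdf"], ["gz"])
parse("dossier/file.target.json.gz")   # -> ("json", "gz")
parse("file.pdf")                      # -> ("pdf", None)
parse("file.unknown")                  # -> (None, None)  unrecognized extensions fall through the defaultdict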

View File

View File

@@ -5,7 +5,7 @@ import uvicorn
from dynaconf import Dynaconf
from fastapi import FastAPI
from pyinfra.utils.config_validation import validate_settings, webserver_validators
from pyinfra.config.validation import validate_settings, webserver_validators
def create_webserver_thread_from_settings(app: FastAPI, settings: Dynaconf) -> threading.Thread:

View File

@@ -36,7 +36,6 @@ requests = "^2.31"
minversion = "6.0"
addopts = "-ra -q"
testpaths = ["tests", "integration"]
norecursedirs = "tests/tests_with_docker_compose"
log_cli = 1
log_cli_level = "DEBUG"

View File

@@ -1,148 +1,8 @@
import gzip
import json
import pytest
from pyinfra.config import get_config, load_settings
from pyinfra.payload_processing.payload import LegacyQueueMessagePayload, QueueMessagePayload
from pyinfra.config.loader import load_settings
@pytest.fixture(scope="session")
def settings():
return load_settings()
@pytest.fixture
def legacy_payload(x_tenant_id, optional_processing_kwargs):
x_tenant_entry = {"X-TENANT-ID": x_tenant_id} if x_tenant_id else {}
optional_processing_kwargs = optional_processing_kwargs or {}
return {
"dossierId": "test",
"fileId": "test",
"targetFileExtension": "target.json.gz",
"responseFileExtension": "response.json.gz",
**x_tenant_entry,
**optional_processing_kwargs,
}
@pytest.fixture
def target_file_path():
return "test/test.target.json.gz"
@pytest.fixture
def response_file_path():
return "test/test.response.json.gz"
@pytest.fixture
def payload(x_tenant_id, optional_processing_kwargs, target_file_path, response_file_path):
x_tenant_entry = {"X-TENANT-ID": x_tenant_id} if x_tenant_id else {}
optional_processing_kwargs = optional_processing_kwargs or {}
return {
"targetFilePath": target_file_path,
"responseFilePath": response_file_path,
**x_tenant_entry,
**optional_processing_kwargs,
}
@pytest.fixture
def legacy_queue_response_payload(x_tenant_id, optional_processing_kwargs):
x_tenant_entry = {"X-TENANT-ID": x_tenant_id} if x_tenant_id else {}
optional_processing_kwargs = optional_processing_kwargs or {}
return {
"dossierId": "test",
"fileId": "test",
**x_tenant_entry,
**optional_processing_kwargs,
}
@pytest.fixture
def queue_response_payload(x_tenant_id, optional_processing_kwargs, target_file_path, response_file_path):
x_tenant_entry = {"X-TENANT-ID": x_tenant_id} if x_tenant_id else {}
optional_processing_kwargs = optional_processing_kwargs or {}
return {
"targetFilePath": target_file_path,
"responseFilePath": response_file_path,
**x_tenant_entry,
**optional_processing_kwargs,
}
@pytest.fixture
def legacy_storage_payload(x_tenant_id, optional_processing_kwargs, processing_result_json):
x_tenant_entry = {"X-TENANT-ID": x_tenant_id} if x_tenant_id else {}
optional_processing_kwargs = optional_processing_kwargs or {}
return {
"dossierId": "test",
"fileId": "test",
"targetFileExtension": "target.json.gz",
"responseFileExtension": "response.json.gz",
**x_tenant_entry,
**optional_processing_kwargs,
"data": processing_result_json,
}
@pytest.fixture
def storage_payload(x_tenant_id, optional_processing_kwargs, processing_result_json, target_file_path, response_file_path):
x_tenant_entry = {"X-TENANT-ID": x_tenant_id} if x_tenant_id else {}
optional_processing_kwargs = optional_processing_kwargs or {}
return {
"targetFilePath": target_file_path,
"responseFilePath": response_file_path,
**x_tenant_entry,
**optional_processing_kwargs,
"data": processing_result_json,
}
@pytest.fixture
def legacy_parsed_payload(
x_tenant_id, optional_processing_kwargs, target_file_path, response_file_path
) -> LegacyQueueMessagePayload:
return LegacyQueueMessagePayload(
dossier_id="test",
file_id="test",
x_tenant_id=x_tenant_id,
target_file_extension="target.json.gz",
response_file_extension="response.json.gz",
target_file_type="json",
target_compression_type="gz",
response_file_type="json",
response_compression_type="gz",
target_file_path=target_file_path,
response_file_path=response_file_path,
processing_kwargs=optional_processing_kwargs or {},
)
@pytest.fixture
def parsed_payload(
x_tenant_id, optional_processing_kwargs, target_file_path, response_file_path
) -> QueueMessagePayload:
return QueueMessagePayload(
x_tenant_id=x_tenant_id,
target_file_type="json",
target_compression_type="gz",
response_file_type="json",
response_compression_type="gz",
target_file_path=target_file_path,
response_file_path=response_file_path,
processing_kwargs=optional_processing_kwargs or {},
)
@pytest.fixture
def target_json_file() -> bytes:
data = {"target": "test"}
enc_data = json.dumps(data).encode("utf-8")
compr_data = gzip.compress(enc_data)
return compr_data
@pytest.fixture
def processing_result_json() -> dict:
return {"response": "test"}

View File

@@ -8,7 +8,7 @@ services:
- MINIO_ROOT_PASSWORD=password
- MINIO_ROOT_USER=root
volumes:
- ./data/minio_store:/data
- /tmp/minio_store:/data
command: server /data
network_mode: "bridge"
rabbitmq:

View File

@@ -6,7 +6,7 @@ import requests
from fastapi import FastAPI
from pyinfra.monitor.prometheus import add_prometheus_endpoint, make_prometheus_processing_time_decorator_from_settings
from pyinfra.webserver import create_webserver_thread_from_settings
from pyinfra.webserver.utils import create_webserver_thread_from_settings
@pytest.fixture(scope="class")

View File

@@ -1,11 +1,14 @@
import gzip
import json
from time import sleep
import pytest
from fastapi import FastAPI
from pyinfra.storage.connection import get_storage_from_settings, get_storage_from_tenant_id
from pyinfra.storage.utils import download_data_as_specified_in_message, upload_data_as_specified_in_message
from pyinfra.utils.cipher import encrypt
from pyinfra.webserver import create_webserver_thread
from pyinfra.webserver.utils import create_webserver_thread
@pytest.fixture(scope="class")
@@ -19,41 +22,6 @@ def storage(storage_backend, settings):
storage.clear_bucket()
@pytest.fixture(scope="class")
def tenant_server_mock(settings, tenant_server_host, tenant_server_port):
app = FastAPI()
@app.get("/azure_tenant")
def get_azure_storage_info():
return {
"azureStorageConnection": {
"connectionString": encrypt(
settings.storage.tenant_server.public_key, settings.storage.azure.connection_string
),
"containerName": settings.storage.azure.container,
}
}
@app.get("/s3_tenant")
def get_s3_storage_info():
return {
"s3StorageConnection": {
"endpoint": settings.storage.s3.endpoint,
"key": settings.storage.s3.key,
"secret": encrypt(settings.storage.tenant_server.public_key, settings.storage.s3.secret),
"region": settings.storage.s3.region,
"bucketName": settings.storage.s3.bucket,
}
}
thread = create_webserver_thread(app, tenant_server_port, tenant_server_host)
thread.daemon = True
thread.start()
sleep(1)
yield
thread.join(timeout=1)
@pytest.mark.parametrize("storage_backend", ["azure", "s3"], scope="class")
class TestStorage:
def test_clearing_bucket_yields_empty_bucket(self, storage):
@@ -103,6 +71,41 @@ class TestStorage:
storage.get_object("folder/file")
@pytest.fixture(scope="class")
def tenant_server_mock(settings, tenant_server_host, tenant_server_port):
app = FastAPI()
@app.get("/azure_tenant")
def get_azure_storage_info():
return {
"azureStorageConnection": {
"connectionString": encrypt(
settings.storage.tenant_server.public_key, settings.storage.azure.connection_string
),
"containerName": settings.storage.azure.container,
}
}
@app.get("/s3_tenant")
def get_s3_storage_info():
return {
"s3StorageConnection": {
"endpoint": settings.storage.s3.endpoint,
"key": settings.storage.s3.key,
"secret": encrypt(settings.storage.tenant_server.public_key, settings.storage.s3.secret),
"region": settings.storage.s3.region,
"bucketName": settings.storage.s3.bucket,
}
}
thread = create_webserver_thread(app, tenant_server_port, tenant_server_host)
thread.daemon = True
thread.start()
sleep(1)
yield
thread.join(timeout=1)
@pytest.mark.parametrize("tenant_id", ["azure_tenant", "s3_tenant"], scope="class")
@pytest.mark.parametrize("tenant_server_host", ["localhost"], scope="class")
@pytest.mark.parametrize("tenant_server_port", [8000], scope="class")
@@ -117,3 +120,39 @@ class TestMultiTenantStorage:
data_received = storage.get_object("file")
assert b"content" == data_received
@pytest.fixture
def payload(payload_type):
if payload_type == "target_response_file_path":
return {
"targetFilePath": "test/file.target.json.gz",
"responseFilePath": "test/file.response.json.gz",
}
elif payload_type == "dossier_id_file_id":
return {
"dossierId": "test",
"fileId": "file",
"targetFileExtension": "target.json.gz",
"responseFileExtension": "response.json.gz",
}
@pytest.mark.parametrize("payload_type", ["target_response_file_path", "dossier_id_file_id"], scope="class")
@pytest.mark.parametrize("storage_backend", ["azure", "s3"], scope="class")
class TestDownloadAndUploadFromMessage:
def test_download_and_upload_from_message(self, storage, payload):
storage.clear_bucket()
input_data = {"data": "success"}
storage.put_object("test/file.target.json.gz", gzip.compress(json.dumps(input_data).encode()))
data = download_data_as_specified_in_message(storage, payload)
assert data == input_data
upload_data_as_specified_in_message(storage, payload, input_data)
data = json.loads(gzip.decompress(storage.get_object("test/file.response.json.gz")).decode())
assert data == {**payload, "data": input_data}

View File

@@ -1,32 +0,0 @@
import pytest
from pyinfra.utils.file_extension_parsing import make_file_extension_parser
@pytest.fixture
def file_extension_parser(file_types, compression_types):
return make_file_extension_parser(file_types, compression_types)
@pytest.mark.parametrize(
"file_path,file_types,compression_types,expected_file_extension,expected_compression_extension",
[
("test.txt", ["txt"], ["gz"], "txt", None),
("test.txt.gz", ["txt"], ["gz"], "txt", "gz"),
("test.txt.gz", [], [], None, None),
("test.txt.gz", ["txt"], [], "txt", None),
("test.txt.gz", [], ["gz"], None, "gz"),
("test", ["txt"], ["gz"], None, None),
],
)
def test_file_extension_parsing(
file_extension_parser,
file_path,
file_types,
compression_types,
expected_file_extension,
expected_compression_extension,
):
file_extension, compression_extension = file_extension_parser(file_path)
assert file_extension == expected_file_extension
assert compression_extension == expected_compression_extension

View File

@@ -1,44 +0,0 @@
import re
import time
import pytest
import requests
from pyinfra.payload_processing.monitor import PrometheusMonitor
@pytest.fixture(scope="class")
def monitored_mock_function(metric_prefix, host, port):
def process(data=None):
time.sleep(2)
return ["result1", "result2", "result3"]
monitor = PrometheusMonitor(metric_prefix, host, port)
return monitor(process)
@pytest.fixture
def metric_endpoint(host, port):
return f"http://{host}:{port}/prometheus"
@pytest.mark.parametrize("metric_prefix, host, port", [("test", "0.0.0.0", 8000)], scope="class")
class TestPrometheusMonitor:
def test_prometheus_endpoint_is_available(self, metric_endpoint, monitored_mock_function):
resp = requests.get(metric_endpoint)
assert resp.status_code == 200
def test_processing_with_a_monitored_fn_increases_parameter_counter(
self,
metric_endpoint,
metric_prefix,
monitored_mock_function,
):
monitored_mock_function(data=None)
resp = requests.get(metric_endpoint)
pattern = re.compile(rf".*{metric_prefix}_processing_time_count (\d\.\d).*")
assert pattern.search(resp.text).group(1) == "1.0"
monitored_mock_function(data=None)
resp = requests.get(metric_endpoint)
assert pattern.search(resp.text).group(1) == "2.0"

View File

@@ -1,48 +0,0 @@
import pytest
from pyinfra.config import get_config
from pyinfra.payload_processing.payload import (
get_queue_message_payload_parser,
format_to_queue_message_response_body,
format_service_processing_result_for_storage,
)
@pytest.fixture
def payload_parser():
config = get_config()
return get_queue_message_payload_parser(config)
@pytest.mark.parametrize("x_tenant_id", [None, "klaus"])
@pytest.mark.parametrize("optional_processing_kwargs", [{}, {"operation": "test"}])
class TestPayloadParsing:
def test_legacy_payload_parsing(self, payload_parser, legacy_payload, legacy_parsed_payload):
parsed_payload = payload_parser(legacy_payload)
assert parsed_payload == legacy_parsed_payload
def test_payload_parsing(self, payload_parser, payload, parsed_payload):
parsed_payload = payload_parser(payload)
assert parsed_payload == parsed_payload
@pytest.mark.parametrize("x_tenant_id", [None, "klaus"])
@pytest.mark.parametrize("optional_processing_kwargs", [{}, {"operation": "test"}])
class TestPayloadFormatting:
def test_legacy_payload_formatting_for_response(self, legacy_parsed_payload, legacy_queue_response_payload):
formatted_payload = format_to_queue_message_response_body(legacy_parsed_payload)
assert formatted_payload == legacy_queue_response_payload
def test_payload_formatting_for_response(self, parsed_payload, queue_response_payload):
formatted_payload = format_to_queue_message_response_body(parsed_payload)
assert formatted_payload == queue_response_payload
def test_legacy_payload_formatting_for_storage(
self, legacy_parsed_payload, processing_result_json, legacy_storage_payload
):
formatted_payload = format_service_processing_result_for_storage(legacy_parsed_payload, processing_result_json)
assert formatted_payload == legacy_storage_payload
def test_payload_formatting_for_storage(self, parsed_payload, processing_result_json, storage_payload):
formatted_payload = format_service_processing_result_for_storage(parsed_payload, processing_result_json)
assert formatted_payload == storage_payload

View File

@@ -1,81 +0,0 @@
import gzip
import json
import pytest
from pyinfra.config import get_config
from pyinfra.payload_processing.payload import get_queue_message_payload_parser
from pyinfra.payload_processing.processor import PayloadProcessor
from pyinfra.storage.storage_info import StorageInfo
from pyinfra.storage.storage_provider import StorageProviderMock
from pyinfra.storage.storages.mock import StorageMock
@pytest.fixture
def bucket_name():
return "test_bucket"
@pytest.fixture
def storage_mock(target_json_file, target_file_path, bucket_name):
storage = StorageMock(target_json_file, target_file_path, bucket_name)
return storage
@pytest.fixture
def storage_info_mock(bucket_name):
return StorageInfo(bucket_name)
@pytest.fixture
def data_processor_mock(processing_result_json):
def inner(data, **kwargs):
return processing_result_json
return inner
@pytest.fixture
def payload_processor(storage_mock, storage_info_mock, data_processor_mock):
storage_provider = StorageProviderMock(storage_mock, storage_info_mock)
payload_parser = get_queue_message_payload_parser(get_config())
return PayloadProcessor(storage_provider, payload_parser, data_processor_mock)
@pytest.mark.parametrize("x_tenant_id", [None, "klaus"])
@pytest.mark.parametrize("optional_processing_kwargs", [{}, {"operation": "test"}])
class TestPayloadProcessor:
def test_payload_processor_yields_correct_response_and_uploads_result_for_legacy_message(
self,
payload_processor,
storage_mock,
bucket_name,
response_file_path,
legacy_payload,
legacy_queue_response_payload,
legacy_storage_payload,
):
response = payload_processor(legacy_payload)
assert response == legacy_queue_response_payload
data_stored = storage_mock.get_object(bucket_name, response_file_path)
assert json.loads(gzip.decompress(data_stored).decode()) == legacy_storage_payload
def test_payload_processor_yields_correct_response_and_uploads_result(
self,
payload_processor,
storage_mock,
bucket_name,
response_file_path,
payload,
queue_response_payload,
storage_payload,
):
response = payload_processor(payload)
assert response == queue_response_payload
data_stored = storage_mock.get_object(bucket_name, response_file_path)
assert json.loads(gzip.decompress(data_stored).decode()) == storage_payload