refactor: download and upload file logic, module structure, remove redundant files so far

parent ec5ad09fa8
commit 6802bf5960
@@ -1,131 +0,0 @@
import os
from os import environ
from pathlib import Path
from typing import Union

from dynaconf import Dynaconf

from pyinfra.utils.url_parsing import validate_and_parse_s3_endpoint


def read_from_environment(environment_variable_name, default_value):
    return environ.get(environment_variable_name, default_value)


def normalize_bool(value: Union[str, bool]):
    return value if isinstance(value, bool) else value in ["True", "true"]


class Config:
    def __init__(self):
        # Logging level for service logger
        self.logging_level_root = read_from_environment("LOGGING_LEVEL_ROOT", "DEBUG")

        # Enables Prometheus monitoring
        self.monitoring_enabled = normalize_bool(read_from_environment("MONITORING_ENABLED", True))

        # Prometheus metric prefix, per convention '{product_name}_{service_name}_{parameter}'
        # In the current implementation, the results of a service define the parameter that is monitored,
        # i.e. analysis result per image means processing time per image is monitored.
        # TODO: add validator since some characters like '-' are not allowed by python prometheus
        self.prometheus_metric_prefix = read_from_environment(
            "PROMETHEUS_METRIC_PREFIX", "redactmanager_research_service_parameter"
        )

        # Prometheus webserver address and port
        self.prometheus_host = "0.0.0.0"
        self.prometheus_port = 8080

        # RabbitMQ host address
        self.rabbitmq_host = read_from_environment("RABBITMQ_HOST", "localhost")

        # RabbitMQ host port
        self.rabbitmq_port = read_from_environment("RABBITMQ_PORT", "5672")

        # RabbitMQ username
        self.rabbitmq_username = read_from_environment("RABBITMQ_USERNAME", "user")

        # RabbitMQ password
        self.rabbitmq_password = read_from_environment("RABBITMQ_PASSWORD", "bitnami")

        # Controls AMQP heartbeat timeout in seconds
        self.rabbitmq_heartbeat = int(read_from_environment("RABBITMQ_HEARTBEAT", 1))

        # Controls AMQP connection sleep timer in seconds
        # important for heartbeat to come through while main function runs on other thread
        self.rabbitmq_connection_sleep = int(read_from_environment("RABBITMQ_CONNECTION_SLEEP", 5))

        # Queue name for requests to the service
        self.request_queue = read_from_environment("REQUEST_QUEUE", "request_queue")

        # Queue name for responses by service
        self.response_queue = read_from_environment("RESPONSE_QUEUE", "response_queue")

        # Queue name for failed messages
        self.dead_letter_queue = read_from_environment("DEAD_LETTER_QUEUE", "dead_letter_queue")

        # The type of storage to use {s3, azure}
        self.storage_backend = read_from_environment("STORAGE_BACKEND", "s3")

        # The bucket / container to pull files specified in queue requests from
        if self.storage_backend == "s3":
            self.storage_bucket = read_from_environment("STORAGE_BUCKET_NAME", "redaction")
        else:
            self.storage_bucket = read_from_environment("STORAGE_AZURECONTAINERNAME", "redaction")

        # S3 connection security flag and endpoint
        storage_address = read_from_environment("STORAGE_ENDPOINT", "http://127.0.0.1:9000")
        self.storage_secure_connection, self.storage_endpoint = validate_and_parse_s3_endpoint(storage_address)

        # User for s3 storage
        self.storage_key = read_from_environment("STORAGE_KEY", "root")

        # Password for s3 storage
        self.storage_secret = read_from_environment("STORAGE_SECRET", "password")

        # Region for s3 storage
        self.storage_region = read_from_environment("STORAGE_REGION", "eu-central-1")

        # Connection string for Azure storage
        self.storage_azureconnectionstring = read_from_environment(
            "STORAGE_AZURECONNECTIONSTRING",
            "DefaultEndpointsProtocol=...",
        )

        # Allowed file types for downloaded and uploaded storage objects that get processed by the service
        self.allowed_file_types = ["json", "pdf"]
        self.allowed_compression_types = ["gz"]

        self.allowed_processing_parameters = ["operation"]

        # config for x-tenant-endpoint to receive storage connection information per tenant
        self.tenant_decryption_public_key = read_from_environment("TENANT_PUBLIC_KEY", "redaction")
        self.tenant_endpoint = read_from_environment(
            "TENANT_ENDPOINT", "http://tenant-user-management:8081/internal-api/tenants"
        )

        # Value to see if we should write a consumer token to a file
        self.write_consumer_token = read_from_environment("WRITE_CONSUMER_TOKEN", "False")


def get_config() -> Config:
    return Config()


def load_settings():
    # TODO: Make dynamic, so that the settings.toml file can be loaded from any location
    # TODO: add validation
    root_path = Path(__file__).resolve().parents[0]  # this is pyinfra/
    repo_root_path = root_path.parents[0]  # this is the root of the repo
    os.environ["ROOT_PATH"] = str(root_path)
    os.environ["REPO_ROOT_PATH"] = str(repo_root_path)

    settings = Dynaconf(
        load_dotenv=True,
        envvar_prefix=False,
        settings_files=[
            repo_root_path / "config" / "settings.toml",
        ],
    )

    return settings
pyinfra/config/__init__.py (new empty file)
pyinfra/config/loader.py (new file, 23 lines)
@@ -0,0 +1,23 @@
import os
from pathlib import Path

from dynaconf import Dynaconf


def load_settings():
    # TODO: Make dynamic, so that the settings.toml file can be loaded from any location
    # TODO: add validation
    root_path = Path(__file__).resolve().parents[1]  # this is pyinfra/
    repo_root_path = root_path.parents[0]  # this is the root of the repo
    os.environ["ROOT_PATH"] = str(root_path)
    os.environ["REPO_ROOT_PATH"] = str(repo_root_path)

    settings = Dynaconf(
        load_dotenv=True,
        envvar_prefix=False,
        settings_files=[
            repo_root_path / "config" / "settings.toml",
        ],
    )

    return settings
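A minimal usage sketch for the new loader module (assuming a config/settings.toml at the repository root; the dotted key shown below is taken from the test code further down in this diff):

    from pyinfra.config.loader import load_settings

    settings = load_settings()
    # Dynaconf exposes the TOML structure via attribute access, e.g.:
    print(settings.storage.s3.endpoint)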
@@ -1,5 +0,0 @@
class ProcessingFailure(RuntimeError):
    pass

class UnknownStorageBackend(Exception):
    pass
@@ -1,5 +1,5 @@
from time import time
from typing import Sized, Callable, TypeVar
from typing import Callable, TypeVar

from dynaconf import Dynaconf
from fastapi import FastAPI
@@ -7,7 +7,7 @@ from funcy import identity
from prometheus_client import generate_latest, CollectorRegistry, REGISTRY, Summary
from starlette.responses import Response

from pyinfra.utils.config_validation import validate_settings, prometheus_validators
from pyinfra.config.validation import validate_settings, prometheus_validators


def add_prometheus_endpoint(app: FastAPI, registry: CollectorRegistry = REGISTRY) -> FastAPI:
@@ -13,7 +13,7 @@ from kn_utils.logging import logger
from pika.adapters.blocking_connection import BlockingChannel, BlockingConnection
from retry import retry

from pyinfra.utils.config_validation import validate_settings, queue_manager_validators
from pyinfra.config.validation import validate_settings, queue_manager_validators

pika_logger = logging.getLogger("pika")
pika_logger.setLevel(logging.WARNING)  # disables non-informative pika log clutter
@@ -1,18 +1,14 @@
from functools import lru_cache, partial
from typing import Callable
from functools import lru_cache

import requests
from dynaconf import Dynaconf
from funcy import compose
from kn_utils.logging import logger

from pyinfra.storage.storages.azure import get_azure_storage_from_settings
from pyinfra.storage.storages.interface import Storage
from pyinfra.storage.storages.s3 import get_s3_storage_from_settings
from pyinfra.storage.storages.storage import Storage
from pyinfra.utils.cipher import decrypt
from pyinfra.utils.compressing import get_decompressor, get_compressor
from pyinfra.utils.config_validation import validate_settings, storage_validators, multi_tenant_storage_validators
from pyinfra.utils.encoding import get_decoder, get_encoder
from pyinfra.config.validation import validate_settings, storage_validators, multi_tenant_storage_validators


def get_storage(settings: Dynaconf, tenant_id: str = None) -> Storage:
@@ -55,7 +51,7 @@ def get_storage_from_tenant_id(tenant_id: str, settings: Dynaconf) -> Storage:
    if maybe_azure:
        connection_string = decrypt(public_key, maybe_azure["connectionString"])
        backend = "azure"
        storage_settings = {
        storage_info = {
            "storage": {
                "azure": {
                    "connection_string": connection_string,
@@ -66,7 +62,7 @@ def get_storage_from_tenant_id(tenant_id: str, settings: Dynaconf) -> Storage:
    elif maybe_s3:
        secret = decrypt(public_key, maybe_s3["secret"])
        backend = "s3"
        storage_settings = {
        storage_info = {
            "storage": {
                "s3": {
                    "endpoint": maybe_s3["endpoint"],
@@ -81,7 +77,7 @@ def get_storage_from_tenant_id(tenant_id: str, settings: Dynaconf) -> Storage:
        raise Exception(f"Unknown storage backend in {response}.")

    storage_settings = Dynaconf()
    storage_settings.update(settings)
    storage_settings.update(storage_info)

    storage = storage_dispatcher[backend](storage_settings)

@@ -94,31 +90,3 @@ storage_dispatcher = {
    "azure": get_azure_storage_from_settings,
    "s3": get_s3_storage_from_settings,
}


@lru_cache(maxsize=10)
def make_downloader(storage: Storage, bucket: str, file_type: str, compression_type: str) -> Callable:
    verify = partial(verify_existence, storage, bucket)
    download = partial(storage.get_object, bucket)
    decompress = get_decompressor(compression_type)
    decode = get_decoder(file_type)

    return compose(decode, decompress, download, verify)


@lru_cache(maxsize=10)
def make_uploader(storage: Storage, bucket: str, file_type: str, compression_type: str) -> Callable:
    upload = partial(storage.put_object, bucket)
    compress = get_compressor(compression_type)
    encode = get_encoder(file_type)

    def inner(file_name, file_bytes):
        upload(file_name, compose(compress, encode)(file_bytes))

    return inner


def verify_existence(storage: Storage, bucket: str, file_name: str) -> str:
    if not storage.exists(file_name):
        raise FileNotFoundError(f"{file_name=} name not found on storage in {storage.bucket=}.")
    return file_name
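A minimal sketch of how this connection module is typically called (names taken from this file's hunks; the settings object is assumed to be a validated Dynaconf instance as configured elsewhere in the repository, and "some-tenant" is an illustrative value):

    from dynaconf import Dynaconf

    from pyinfra.storage.connection import get_storage

    settings = Dynaconf(settings_files=["config/settings.toml"])

    # Backend chosen from the settings (s3 or azure).
    storage = get_storage(settings)

    # With a tenant id, connection details are resolved per tenant before dispatching.
    tenant_storage = get_storage(settings, tenant_id="some-tenant")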
@@ -7,8 +7,8 @@ from dynaconf import Dynaconf
from kn_utils.logging import logger
from retry import retry

from pyinfra.storage.storages.interface import Storage
from pyinfra.utils.config_validation import azure_storage_validators, validate_settings
from pyinfra.storage.storages.storage import Storage
from pyinfra.config.validation import azure_storage_validators, validate_settings

logging.getLogger("azure").setLevel(logging.WARNING)
logging.getLogger("urllib3").setLevel(logging.WARNING)
@@ -1,39 +0,0 @@
from pyinfra.storage.storages.interface import Storage


class StorageMock(Storage):
    def __init__(self, data: bytes = None, file_name: str = None, bucket: str = None):
        self.data = data
        self.file_name = file_name
        self._bucket = bucket

    @property
    def bucket(self):
        return self._bucket

    def make_bucket(self):
        pass

    def has_bucket(self):
        return True

    def put_object(self, object_name, data):
        self.file_name = object_name
        self.data = data

    def exists(self, object_name):
        return self.file_name == object_name

    def get_object(self, object_name):
        return self.data

    def get_all_objects(self):
        raise NotImplementedError

    def clear_bucket(self):
        self._bucket = None
        self.file_name = None
        self.data = None

    def get_all_object_names(self):
        raise NotImplementedError
@@ -7,8 +7,8 @@ from kn_utils.logging import logger
from minio import Minio
from retry import retry

from pyinfra.storage.storages.interface import Storage
from pyinfra.utils.config_validation import validate_settings, s3_storage_validators
from pyinfra.storage.storages.storage import Storage
from pyinfra.config.validation import validate_settings, s3_storage_validators
from pyinfra.utils.url_parsing import validate_and_parse_s3_endpoint
pyinfra/storage/utils.py (new file, 106 lines)
@@ -0,0 +1,106 @@
import gzip
import json
from typing import Union

from kn_utils.logging import logger
from pydantic import BaseModel, ValidationError

from pyinfra.storage.storages.storage import Storage


class DossierIdFileIdDownloadPayload(BaseModel):
    dossierId: str
    fileId: str
    targetFileExtension: str

    @property
    def targetFilePath(self):
        return f"{self.dossierId}/{self.fileId}.{self.targetFileExtension}"


class DossierIdFileIdUploadPayload(BaseModel):
    dossierId: str
    fileId: str
    responseFileExtension: str

    @property
    def responseFilePath(self):
        return f"{self.dossierId}/{self.fileId}.{self.responseFileExtension}"


class TargetResponseFilePathDownloadPayload(BaseModel):
    targetFilePath: str


class TargetResponseFilePathUploadPayload(BaseModel):
    responseFilePath: str


def download_data_as_specified_in_message(storage: Storage, raw_payload: dict) -> Union[dict, bytes]:
    """Convenience function to download a file specified in a message payload.
    Supports both legacy and new payload formats.

    If the content is compressed with gzip (.gz), it will be decompressed (-> bytes).
    If the content is a json file, it will be decoded (-> dict).
    If no file is specified in the payload or the file does not exist in storage, an exception will be raised.
    In all other cases, the content will be returned as is (-> bytes).

    This function can be extended in the future as needed (e.g. handling of more file types), but since further
    requirements are not specified at this point in time, and it is unclear what these would entail, the code is kept
    simple for now to improve readability, maintainability and avoid refactoring efforts of generic solutions that
    weren't as generic as they seemed.

    """

    try:
        if "dossierId" in raw_payload:
            payload = DossierIdFileIdDownloadPayload(**raw_payload)
        else:
            payload = TargetResponseFilePathDownloadPayload(**raw_payload)
    except ValidationError:
        raise ValueError("No download file path found in payload, nothing to download.")

    if not storage.exists(payload.targetFilePath):
        raise FileNotFoundError(f"File '{payload.targetFilePath}' does not exist in storage.")

    data = storage.get_object(payload.targetFilePath)

    data = gzip.decompress(data) if ".gz" in payload.targetFilePath else data
    data = json.loads(data.decode("utf-8")) if ".json" in payload.targetFilePath else data

    return data


def upload_data_as_specified_in_message(storage: Storage, raw_payload: dict, data):
    """Convenience function to upload a file specified in a message payload. For now, only json-dump-able data is
    supported. The storage json consists of the raw_payload, which is extended with a 'data' key, containing the
    data to be uploaded.

    If the content is not a json-dump-able object, an exception will be raised.
    If the result file identifier specifies compression with gzip (.gz), it will be compressed before upload.

    This function can be extended in the future as needed (e.g. if we need to upload images), but since further
    requirements are not specified at this point in time, and it is unclear what these would entail, the code is kept
    simple for now to improve readability, maintainability and avoid refactoring efforts of generic solutions that
    weren't as generic as they seemed.
    """

    try:
        if "dossierId" in raw_payload:
            payload = DossierIdFileIdUploadPayload(**raw_payload)
        else:
            payload = TargetResponseFilePathUploadPayload(**raw_payload)
    except ValidationError:
        raise ValueError("No upload file path found in payload, nothing to upload.")

    if ".json" not in payload.responseFilePath:
        raise ValueError("Only json-dump-able data can be uploaded.")

    data = {**raw_payload, "data": data}

    data = json.dumps(data).encode("utf-8")
    data = gzip.compress(data) if ".gz" in payload.responseFilePath else data

    storage.put_object(payload.responseFilePath, data)

    logger.info(f"Uploaded {payload.responseFilePath} to storage.")
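A minimal sketch of how these two helpers are exercised together (mirroring the integration test further down in this diff; the InMemoryStorage class here is a hypothetical stand-in for a real Storage implementation):

    import gzip
    import json

    from pyinfra.storage.utils import download_data_as_specified_in_message, upload_data_as_specified_in_message


    class InMemoryStorage:
        """Hypothetical dict-backed storage, for illustration only."""

        def __init__(self):
            self.objects = {}

        def exists(self, object_name):
            return object_name in self.objects

        def get_object(self, object_name):
            return self.objects[object_name]

        def put_object(self, object_name, data):
            self.objects[object_name] = data


    storage = InMemoryStorage()
    payload = {"targetFilePath": "test/file.target.json.gz", "responseFilePath": "test/file.response.json.gz"}

    # Seed the target object as gzipped json, which is what the download helper expects for *.json.gz paths.
    storage.put_object("test/file.target.json.gz", gzip.compress(json.dumps({"data": "success"}).encode()))

    result = download_data_as_specified_in_message(storage, payload)   # -> {"data": "success"}
    upload_data_as_specified_in_message(storage, payload, result)      # stores payload plus a "data" key, gzipped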
@@ -1,22 +0,0 @@
import gzip
from typing import Union, Callable

from funcy import identity


def get_decompressor(compression_type: Union[str, None]) -> Callable:
    if not compression_type:
        return identity
    elif "gz" in compression_type:
        return gzip.decompress
    else:
        raise ValueError(f"{compression_type=} is not supported.")


def get_compressor(compression_type: str) -> Callable:
    if not compression_type:
        return identity
    elif "gz" in compression_type:
        return gzip.compress
    else:
        raise ValueError(f"{compression_type=} is not supported.")
@@ -1,5 +0,0 @@
from funcy import project


def safe_project(mapping, keys) -> dict:
    return project(mapping, keys) if mapping else {}
@@ -1,28 +0,0 @@
import json
from typing import Callable

from funcy import identity


def decode_json(data: bytes) -> dict:
    return json.loads(data.decode("utf-8"))


def encode_json(data: dict) -> bytes:
    return json.dumps(data).encode("utf-8")


def get_decoder(file_type: str) -> Callable:
    if "json" in file_type:
        return decode_json
    elif "pdf" in file_type:
        return identity
    else:
        raise ValueError(f"{file_type=} is not supported.")


def get_encoder(file_type: str) -> Callable:
    if "json" in file_type:
        return encode_json
    else:
        raise ValueError(f"{file_type=} is not supported.")
@@ -1,41 +0,0 @@
from collections import defaultdict
from typing import Callable

from funcy import merge


def make_file_extension_parser(file_types, compression_types):
    ext2_type2ext = make_ext2_type2ext(file_types, compression_types)
    ext_to_type2ext = make_ext_to_type2ext(ext2_type2ext)

    def inner(path):
        file_extensions = parse_file_extensions(path, ext_to_type2ext)
        return file_extensions.get("file_type"), file_extensions.get("compression_type")

    return inner


def make_ext2_type2ext(file_type_extensions, compression_type_extensions):
    def make_ext_to_ext2type(ext_type):
        return lambda ext: {ext_type: ext}

    ext_to_file_type_mapper = make_ext_to_ext2type("file_type")
    ext_to_compression_type_mapper = make_ext_to_ext2type("compression_type")
    return defaultdict(
        lambda: lambda _: {},
        {
            **{e: ext_to_file_type_mapper for e in file_type_extensions},
            **{e: ext_to_compression_type_mapper for e in compression_type_extensions},
        },
    )


def make_ext_to_type2ext(ext2_type2ext):
    def ext_to_type2ext(ext):
        return ext2_type2ext[ext](ext)

    return ext_to_type2ext


def parse_file_extensions(path, ext_to_type2ext: Callable):
    return merge(*map(ext_to_type2ext, path.split(".")))
pyinfra/webserver/__init__.py (new empty file)
@@ -5,7 +5,7 @@ import uvicorn
from dynaconf import Dynaconf
from fastapi import FastAPI

from pyinfra.utils.config_validation import validate_settings, webserver_validators
from pyinfra.config.validation import validate_settings, webserver_validators


def create_webserver_thread_from_settings(app: FastAPI, settings: Dynaconf) -> threading.Thread:
@@ -36,7 +36,6 @@ requests = "^2.31"
minversion = "6.0"
addopts = "-ra -q"
testpaths = ["tests", "integration"]
norecursedirs = "tests/tests_with_docker_compose"
log_cli = 1
log_cli_level = "DEBUG"
@@ -1,148 +1,8 @@
import gzip
import json
import pytest

from pyinfra.config import get_config, load_settings
from pyinfra.payload_processing.payload import LegacyQueueMessagePayload, QueueMessagePayload
from pyinfra.config.loader import load_settings


@pytest.fixture(scope="session")
def settings():
    return load_settings()


@pytest.fixture
def legacy_payload(x_tenant_id, optional_processing_kwargs):
    x_tenant_entry = {"X-TENANT-ID": x_tenant_id} if x_tenant_id else {}
    optional_processing_kwargs = optional_processing_kwargs or {}
    return {
        "dossierId": "test",
        "fileId": "test",
        "targetFileExtension": "target.json.gz",
        "responseFileExtension": "response.json.gz",
        **x_tenant_entry,
        **optional_processing_kwargs,
    }


@pytest.fixture
def target_file_path():
    return "test/test.target.json.gz"


@pytest.fixture
def response_file_path():
    return "test/test.response.json.gz"


@pytest.fixture
def payload(x_tenant_id, optional_processing_kwargs, target_file_path, response_file_path):
    x_tenant_entry = {"X-TENANT-ID": x_tenant_id} if x_tenant_id else {}
    optional_processing_kwargs = optional_processing_kwargs or {}
    return {
        "targetFilePath": target_file_path,
        "responseFilePath": response_file_path,
        **x_tenant_entry,
        **optional_processing_kwargs,
    }


@pytest.fixture
def legacy_queue_response_payload(x_tenant_id, optional_processing_kwargs):
    x_tenant_entry = {"X-TENANT-ID": x_tenant_id} if x_tenant_id else {}
    optional_processing_kwargs = optional_processing_kwargs or {}
    return {
        "dossierId": "test",
        "fileId": "test",
        **x_tenant_entry,
        **optional_processing_kwargs,
    }


@pytest.fixture
def queue_response_payload(x_tenant_id, optional_processing_kwargs, target_file_path, response_file_path):
    x_tenant_entry = {"X-TENANT-ID": x_tenant_id} if x_tenant_id else {}
    optional_processing_kwargs = optional_processing_kwargs or {}
    return {
        "targetFilePath": target_file_path,
        "responseFilePath": response_file_path,
        **x_tenant_entry,
        **optional_processing_kwargs,
    }


@pytest.fixture
def legacy_storage_payload(x_tenant_id, optional_processing_kwargs, processing_result_json):
    x_tenant_entry = {"X-TENANT-ID": x_tenant_id} if x_tenant_id else {}
    optional_processing_kwargs = optional_processing_kwargs or {}
    return {
        "dossierId": "test",
        "fileId": "test",
        "targetFileExtension": "target.json.gz",
        "responseFileExtension": "response.json.gz",
        **x_tenant_entry,
        **optional_processing_kwargs,
        "data": processing_result_json,
    }


@pytest.fixture
def storage_payload(x_tenant_id, optional_processing_kwargs, processing_result_json, target_file_path, response_file_path):
    x_tenant_entry = {"X-TENANT-ID": x_tenant_id} if x_tenant_id else {}
    optional_processing_kwargs = optional_processing_kwargs or {}
    return {
        "targetFilePath": target_file_path,
        "responseFilePath": response_file_path,
        **x_tenant_entry,
        **optional_processing_kwargs,
        "data": processing_result_json,
    }


@pytest.fixture
def legacy_parsed_payload(
    x_tenant_id, optional_processing_kwargs, target_file_path, response_file_path
) -> LegacyQueueMessagePayload:
    return LegacyQueueMessagePayload(
        dossier_id="test",
        file_id="test",
        x_tenant_id=x_tenant_id,
        target_file_extension="target.json.gz",
        response_file_extension="response.json.gz",
        target_file_type="json",
        target_compression_type="gz",
        response_file_type="json",
        response_compression_type="gz",
        target_file_path=target_file_path,
        response_file_path=response_file_path,
        processing_kwargs=optional_processing_kwargs or {},
    )


@pytest.fixture
def parsed_payload(
    x_tenant_id, optional_processing_kwargs, target_file_path, response_file_path
) -> QueueMessagePayload:
    return QueueMessagePayload(
        x_tenant_id=x_tenant_id,
        target_file_type="json",
        target_compression_type="gz",
        response_file_type="json",
        response_compression_type="gz",
        target_file_path=target_file_path,
        response_file_path=response_file_path,
        processing_kwargs=optional_processing_kwargs or {},
    )


@pytest.fixture
def target_json_file() -> bytes:
    data = {"target": "test"}
    enc_data = json.dumps(data).encode("utf-8")
    compr_data = gzip.compress(enc_data)
    return compr_data


@pytest.fixture
def processing_result_json() -> dict:
    return {"response": "test"}
@@ -8,7 +8,7 @@ services:
      - MINIO_ROOT_PASSWORD=password
      - MINIO_ROOT_USER=root
    volumes:
      - ./data/minio_store:/data
      - /tmp/minio_store:/data
    command: server /data
    network_mode: "bridge"
  rabbitmq:
@@ -6,7 +6,7 @@ import requests
from fastapi import FastAPI

from pyinfra.monitor.prometheus import add_prometheus_endpoint, make_prometheus_processing_time_decorator_from_settings
from pyinfra.webserver import create_webserver_thread_from_settings
from pyinfra.webserver.utils import create_webserver_thread_from_settings


@pytest.fixture(scope="class")
@@ -1,11 +1,14 @@
import gzip
import json
from time import sleep

import pytest
from fastapi import FastAPI

from pyinfra.storage.connection import get_storage_from_settings, get_storage_from_tenant_id
from pyinfra.storage.utils import download_data_as_specified_in_message, upload_data_as_specified_in_message
from pyinfra.utils.cipher import encrypt
from pyinfra.webserver import create_webserver_thread
from pyinfra.webserver.utils import create_webserver_thread


@pytest.fixture(scope="class")
@@ -19,41 +22,6 @@ def storage(storage_backend, settings):
    storage.clear_bucket()


@pytest.fixture(scope="class")
def tenant_server_mock(settings, tenant_server_host, tenant_server_port):
    app = FastAPI()

    @app.get("/azure_tenant")
    def get_azure_storage_info():
        return {
            "azureStorageConnection": {
                "connectionString": encrypt(
                    settings.storage.tenant_server.public_key, settings.storage.azure.connection_string
                ),
                "containerName": settings.storage.azure.container,
            }
        }

    @app.get("/s3_tenant")
    def get_s3_storage_info():
        return {
            "s3StorageConnection": {
                "endpoint": settings.storage.s3.endpoint,
                "key": settings.storage.s3.key,
                "secret": encrypt(settings.storage.tenant_server.public_key, settings.storage.s3.secret),
                "region": settings.storage.s3.region,
                "bucketName": settings.storage.s3.bucket,
            }
        }

    thread = create_webserver_thread(app, tenant_server_port, tenant_server_host)
    thread.daemon = True
    thread.start()
    sleep(1)
    yield
    thread.join(timeout=1)


@pytest.mark.parametrize("storage_backend", ["azure", "s3"], scope="class")
class TestStorage:
    def test_clearing_bucket_yields_empty_bucket(self, storage):
@@ -103,6 +71,41 @@ class TestStorage:
            storage.get_object("folder/file")


@pytest.fixture(scope="class")
def tenant_server_mock(settings, tenant_server_host, tenant_server_port):
    app = FastAPI()

    @app.get("/azure_tenant")
    def get_azure_storage_info():
        return {
            "azureStorageConnection": {
                "connectionString": encrypt(
                    settings.storage.tenant_server.public_key, settings.storage.azure.connection_string
                ),
                "containerName": settings.storage.azure.container,
            }
        }

    @app.get("/s3_tenant")
    def get_s3_storage_info():
        return {
            "s3StorageConnection": {
                "endpoint": settings.storage.s3.endpoint,
                "key": settings.storage.s3.key,
                "secret": encrypt(settings.storage.tenant_server.public_key, settings.storage.s3.secret),
                "region": settings.storage.s3.region,
                "bucketName": settings.storage.s3.bucket,
            }
        }

    thread = create_webserver_thread(app, tenant_server_port, tenant_server_host)
    thread.daemon = True
    thread.start()
    sleep(1)
    yield
    thread.join(timeout=1)


@pytest.mark.parametrize("tenant_id", ["azure_tenant", "s3_tenant"], scope="class")
@pytest.mark.parametrize("tenant_server_host", ["localhost"], scope="class")
@pytest.mark.parametrize("tenant_server_port", [8000], scope="class")
@@ -117,3 +120,39 @@ class TestMultiTenantStorage:
        data_received = storage.get_object("file")

        assert b"content" == data_received


@pytest.fixture
def payload(payload_type):
    if payload_type == "target_response_file_path":
        return {
            "targetFilePath": "test/file.target.json.gz",
            "responseFilePath": "test/file.response.json.gz",
        }
    elif payload_type == "dossier_id_file_id":
        return {
            "dossierId": "test",
            "fileId": "file",
            "targetFileExtension": "target.json.gz",
            "responseFileExtension": "response.json.gz",
        }


@pytest.mark.parametrize("payload_type", ["target_response_file_path", "dossier_id_file_id"], scope="class")
@pytest.mark.parametrize("storage_backend", ["azure", "s3"], scope="class")
class TestDownloadAndUploadFromMessage:
    def test_download_and_upload_from_message(self, storage, payload):
        storage.clear_bucket()

        input_data = {"data": "success"}

        storage.put_object("test/file.target.json.gz", gzip.compress(json.dumps(input_data).encode()))

        data = download_data_as_specified_in_message(storage, payload)

        assert data == input_data

        upload_data_as_specified_in_message(storage, payload, input_data)
        data = json.loads(gzip.decompress(storage.get_object("test/file.response.json.gz")).decode())

        assert data == {**payload, "data": input_data}
@@ -1,32 +0,0 @@
import pytest

from pyinfra.utils.file_extension_parsing import make_file_extension_parser


@pytest.fixture
def file_extension_parser(file_types, compression_types):
    return make_file_extension_parser(file_types, compression_types)


@pytest.mark.parametrize(
    "file_path,file_types,compression_types,expected_file_extension,expected_compression_extension",
    [
        ("test.txt", ["txt"], ["gz"], "txt", None),
        ("test.txt.gz", ["txt"], ["gz"], "txt", "gz"),
        ("test.txt.gz", [], [], None, None),
        ("test.txt.gz", ["txt"], [], "txt", None),
        ("test.txt.gz", [], ["gz"], None, "gz"),
        ("test", ["txt"], ["gz"], None, None),
    ],
)
def test_file_extension_parsing(
    file_extension_parser,
    file_path,
    file_types,
    compression_types,
    expected_file_extension,
    expected_compression_extension,
):
    file_extension, compression_extension = file_extension_parser(file_path)
    assert file_extension == expected_file_extension
    assert compression_extension == expected_compression_extension
@@ -1,44 +0,0 @@
import re
import time

import pytest
import requests

from pyinfra.payload_processing.monitor import PrometheusMonitor


@pytest.fixture(scope="class")
def monitored_mock_function(metric_prefix, host, port):
    def process(data=None):
        time.sleep(2)
        return ["result1", "result2", "result3"]

    monitor = PrometheusMonitor(metric_prefix, host, port)
    return monitor(process)


@pytest.fixture
def metric_endpoint(host, port):
    return f"http://{host}:{port}/prometheus"


@pytest.mark.parametrize("metric_prefix, host, port", [("test", "0.0.0.0", 8000)], scope="class")
class TestPrometheusMonitor:
    def test_prometheus_endpoint_is_available(self, metric_endpoint, monitored_mock_function):
        resp = requests.get(metric_endpoint)
        assert resp.status_code == 200

    def test_processing_with_a_monitored_fn_increases_parameter_counter(
        self,
        metric_endpoint,
        metric_prefix,
        monitored_mock_function,
    ):
        monitored_mock_function(data=None)
        resp = requests.get(metric_endpoint)
        pattern = re.compile(rf".*{metric_prefix}_processing_time_count (\d\.\d).*")
        assert pattern.search(resp.text).group(1) == "1.0"

        monitored_mock_function(data=None)
        resp = requests.get(metric_endpoint)
        assert pattern.search(resp.text).group(1) == "2.0"
@@ -1,48 +0,0 @@
import pytest

from pyinfra.config import get_config
from pyinfra.payload_processing.payload import (
    get_queue_message_payload_parser,
    format_to_queue_message_response_body,
    format_service_processing_result_for_storage,
)


@pytest.fixture
def payload_parser():
    config = get_config()
    return get_queue_message_payload_parser(config)


@pytest.mark.parametrize("x_tenant_id", [None, "klaus"])
@pytest.mark.parametrize("optional_processing_kwargs", [{}, {"operation": "test"}])
class TestPayloadParsing:
    def test_legacy_payload_parsing(self, payload_parser, legacy_payload, legacy_parsed_payload):
        parsed_payload = payload_parser(legacy_payload)
        assert parsed_payload == legacy_parsed_payload

    def test_payload_parsing(self, payload_parser, payload, parsed_payload):
        parsed_payload = payload_parser(payload)
        assert parsed_payload == parsed_payload


@pytest.mark.parametrize("x_tenant_id", [None, "klaus"])
@pytest.mark.parametrize("optional_processing_kwargs", [{}, {"operation": "test"}])
class TestPayloadFormatting:
    def test_legacy_payload_formatting_for_response(self, legacy_parsed_payload, legacy_queue_response_payload):
        formatted_payload = format_to_queue_message_response_body(legacy_parsed_payload)
        assert formatted_payload == legacy_queue_response_payload

    def test_payload_formatting_for_response(self, parsed_payload, queue_response_payload):
        formatted_payload = format_to_queue_message_response_body(parsed_payload)
        assert formatted_payload == queue_response_payload

    def test_legacy_payload_formatting_for_storage(
        self, legacy_parsed_payload, processing_result_json, legacy_storage_payload
    ):
        formatted_payload = format_service_processing_result_for_storage(legacy_parsed_payload, processing_result_json)
        assert formatted_payload == legacy_storage_payload

    def test_payload_formatting_for_storage(self, parsed_payload, processing_result_json, storage_payload):
        formatted_payload = format_service_processing_result_for_storage(parsed_payload, processing_result_json)
        assert formatted_payload == storage_payload
@@ -1,81 +0,0 @@
import gzip
import json
import pytest

from pyinfra.config import get_config
from pyinfra.payload_processing.payload import get_queue_message_payload_parser
from pyinfra.payload_processing.processor import PayloadProcessor
from pyinfra.storage.storage_info import StorageInfo
from pyinfra.storage.storage_provider import StorageProviderMock
from pyinfra.storage.storages.mock import StorageMock


@pytest.fixture
def bucket_name():
    return "test_bucket"


@pytest.fixture
def storage_mock(target_json_file, target_file_path, bucket_name):
    storage = StorageMock(target_json_file, target_file_path, bucket_name)
    return storage


@pytest.fixture
def storage_info_mock(bucket_name):
    return StorageInfo(bucket_name)


@pytest.fixture
def data_processor_mock(processing_result_json):
    def inner(data, **kwargs):
        return processing_result_json

    return inner


@pytest.fixture
def payload_processor(storage_mock, storage_info_mock, data_processor_mock):
    storage_provider = StorageProviderMock(storage_mock, storage_info_mock)
    payload_parser = get_queue_message_payload_parser(get_config())
    return PayloadProcessor(storage_provider, payload_parser, data_processor_mock)


@pytest.mark.parametrize("x_tenant_id", [None, "klaus"])
@pytest.mark.parametrize("optional_processing_kwargs", [{}, {"operation": "test"}])
class TestPayloadProcessor:
    def test_payload_processor_yields_correct_response_and_uploads_result_for_legacy_message(
        self,
        payload_processor,
        storage_mock,
        bucket_name,
        response_file_path,
        legacy_payload,
        legacy_queue_response_payload,
        legacy_storage_payload,
    ):
        response = payload_processor(legacy_payload)

        assert response == legacy_queue_response_payload

        data_stored = storage_mock.get_object(bucket_name, response_file_path)

        assert json.loads(gzip.decompress(data_stored).decode()) == legacy_storage_payload

    def test_payload_processor_yields_correct_response_and_uploads_result(
        self,
        payload_processor,
        storage_mock,
        bucket_name,
        response_file_path,
        payload,
        queue_response_payload,
        storage_payload,
    ):
        response = payload_processor(payload)

        assert response == queue_response_payload

        data_stored = storage_mock.get_object(bucket_name, response_file_path)

        assert json.loads(gzip.decompress(data_stored).decode()) == storage_payload