feat: BREAKING CHANGE: download callback now forwards all files as bytes

This commit is contained in:
Julius Unverfehrt 2025-01-08 15:00:51 +01:00
parent 5ce66f18a0
commit 5c4400aa8b
6 changed files with 18 additions and 244 deletions

View File

@@ -5,7 +5,7 @@ from kn_utils.logging import logger
from pyinfra.storage.connection import get_storage
from pyinfra.storage.utils import (
download_data_as_specified_in_message,
download_data_bytes_as_specified_in_message,
upload_data_as_specified_in_message,
)
@@ -28,7 +28,7 @@ def make_download_process_upload_callback(data_processor: DataProcessor, setting
storage = get_storage(settings, queue_message_payload.get("X-TENANT-ID"))
data = download_data_as_specified_in_message(storage, queue_message_payload)
data: dict[str, bytes] | bytes = download_data_bytes_as_specified_in_message(storage, queue_message_payload)
result = data_processor(data, queue_message_payload)
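
Since the callback now hands the downloaded content to the processor as raw bytes (or a dict of bytes for multi-file payloads), any decompression or parsing that previously happened in the download helper has to be done by the consumer. A minimal sketch of such a processor; the function name, payload handling, and the assumption that the files are gzipped JSON are illustrative, not part of this commit:

import gzip
import json

def example_processor(data: dict[str, bytes] | bytes, payload: dict) -> dict:
    # Hypothetical consumer: decodes gzipped JSON itself, since the callback
    # no longer decompresses or parses downloaded files.
    if isinstance(data, dict):
        # Multi-file payload: one raw bytes blob per key.
        return {key: json.loads(gzip.decompress(raw)) for key, raw in data.items()}
    return json.loads(gzip.decompress(data))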

View File

@@ -1,7 +1,6 @@
import gzip
import json
from functools import singledispatch
from typing import Union
from kn_utils.logging import logger
from pydantic import BaseModel, ValidationError
@@ -53,28 +52,19 @@ class TenantIdDossierIdFileIdUploadPayload(BaseModel):
class TargetResponseFilePathDownloadPayload(BaseModel):
targetFilePath: Union[str, dict]
targetFilePath: str | dict[str, str]
class TargetResponseFilePathUploadPayload(BaseModel):
responseFilePath: str
def download_data_as_specified_in_message(storage: Storage, raw_payload: dict) -> Union[dict, bytes]:
def download_data_bytes_as_specified_in_message(storage: Storage, raw_payload: dict) -> dict[str, bytes] | bytes:
"""Convenience function to download a file specified in a message payload.
Supports both legacy and new payload formats. Also supports downloading multiple files at once, which should
be specified in a dictionary under the 'targetFilePath' key with the file path as value.
If the content is compressed with gzip (.gz), it will be decompressed (-> bytes).
If the content is a json file, it will be decoded (-> dict).
If no file is specified in the payload or the file does not exist in storage, an exception will be raised.
In all other cases, the content will be returned as is (-> bytes).
This function can be extended in the future as needed (e.g. handling of more file types), but since further
requirements are not specified at this point in time, and it is unclear what these would entail, the code is kept
simple for now to improve readability, maintainability and avoid refactoring efforts of generic solutions that
weren't as generic as they seemed.
In all cases, the content will be returned as is (-> bytes).
"""
try:
@@ -93,7 +83,7 @@ def download_data_as_specified_in_message(storage: Storage, raw_payload: dict) -
@singledispatch
def _download(file_path_or_file_path_dict: Union[str, dict], storage: Storage) -> Union[dict, bytes]:
def _download(file_path_or_file_path_dict: str | dict[str, str], storage: Storage) -> dict[str, bytes] | bytes:
pass
@@ -103,23 +93,13 @@ def _download_single_file(file_path: str, storage: Storage) -> bytes:
raise FileNotFoundError(f"File '{file_path}' does not exist in storage.")
data = storage.get_object(file_path)
data = gzip.decompress(data) if ".gz" in file_path else data
if ".json" in file_path:
data = json.loads(data.decode("utf-8"))
elif ".proto" in file_path:
data = ProtoDataLoader()(file_path, data)
else:
pass # identity for other file types
logger.info(f"Downloaded {file_path} from storage.")
return data
@_download.register(dict)
def _download_multiple_files(file_path_dict: dict, storage: Storage) -> dict:
def _download_multiple_files(file_path_dict: dict, storage: Storage) -> dict[str, bytes]:
return {key: _download(value, storage) for key, value in file_path_dict.items()}
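
For illustration, downloading multiple files with the dictionary form of 'targetFilePath' described in the docstring could look like the sketch below. The payload keys, file paths, and the gzipped-JSON decoding are assumptions for the example, and storage is any Storage connection (e.g. obtained via get_storage):

payload = {"targetFilePath": {"file_1": "test/a.json.gz", "file_2": "test/b.json.gz"}}
data = download_data_bytes_as_specified_in_message(storage, payload)
# data is a dict[str, bytes]; decoding is now the caller's responsibility, e.g.:
# json.loads(gzip.decompress(data["file_1"])) for a gzipped JSON file.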

View File

@@ -27,6 +27,8 @@ def storage(storage_backend, settings):
def queue_manager(settings):
settings.rabbitmq_heartbeat = 10
settings.connection_sleep = 5
settings.rabbitmq.max_retries = 3
settings.rabbitmq.max_delay = 10
queue_manager = QueueManager(settings)
yield queue_manager

View File

@@ -3,7 +3,6 @@ from sys import stdout
from time import sleep
import pika
import pytest
from kn_utils.logging import logger
logger.remove()

View File

@@ -7,7 +7,7 @@ from fastapi import FastAPI
from pyinfra.storage.connection import get_storage_for_tenant
from pyinfra.storage.utils import (
download_data_as_specified_in_message,
download_data_bytes_as_specified_in_message,
upload_data_as_specified_in_message,
)
from pyinfra.utils.cipher import encrypt
@@ -139,16 +139,6 @@ def payload(payload_type):
}
@pytest.fixture
def expected_data(payload_type):
if payload_type == "target_response_file_path":
return {"data": "success"}
elif payload_type == "dossier_id_file_id":
return {"dossierId": "test", "fileId": "file", "data": "success"}
elif payload_type == "target_file_dict":
return {"file_1": {"data": "success"}, "file_2": {"data": "success"}}
@pytest.mark.parametrize(
"payload_type",
[
@@ -160,17 +150,17 @@ def expected_data(payload_type):
)
@pytest.mark.parametrize("storage_backend", ["azure", "s3"], scope="class")
class TestDownloadAndUploadFromMessage:
def test_download_and_upload_from_message(self, storage, payload, expected_data, payload_type):
def test_download_and_upload_from_message(self, storage, payload, payload_type):
storage.clear_bucket()
upload_data = expected_data if payload_type != "target_file_dict" else expected_data["file_1"]
storage.put_object("test/file.target.json.gz", gzip.compress(json.dumps(upload_data).encode()))
result = {"process_result": "success"}
storage_data = {**payload, "data": result}
packed_data = gzip.compress(json.dumps(storage_data).encode())
data = download_data_as_specified_in_message(storage, payload)
storage.put_object("test/file.target.json.gz", packed_data)
assert data == expected_data
upload_data_as_specified_in_message(storage, payload, expected_data)
_ = download_data_bytes_as_specified_in_message(storage, payload)
upload_data_as_specified_in_message(storage, payload, result)
data = json.loads(gzip.decompress(storage.get_object("test/file.response.json.gz")).decode())
assert data == {**payload, "data": expected_data}
assert data == storage_data

View File

@ -1,197 +0,0 @@
import gzip
import json
from pathlib import Path
import pytest
from deepdiff import DeepDiff
from pyinfra.storage.proto_data_loader import ProtoDataLoader
enum = 1
@pytest.fixture
def test_data_dir():
return Path(__file__).parents[1] / "data"
@pytest.fixture
def document_data(request, test_data_dir) -> (str, bytes, dict | list):
doc_type = request.param
# Search for relevant doc_type file pairs - there should be one proto and one json file per document type
input_file_path = next(test_data_dir.glob(f"*.{doc_type}.proto.gz"), None)
target_file_path = next(test_data_dir.glob(f"*.{doc_type}.json.gz"), None)
input_data = input_file_path.read_bytes()
target_data = json.loads(gzip.decompress(target_file_path.read_bytes()))
return input_file_path, input_data, target_data
@pytest.fixture
def proto_data_loader():
return ProtoDataLoader()
@pytest.fixture
def should_match():
return [
"a.DOCUMENT_STRUCTURE.proto.gz",
"a.DOCUMENT_TEXT.proto.gz",
"a.DOCUMENT_PAGES.proto.gz",
"a.DOCUMENT_POSITION.proto.gz",
"b.DOCUMENT_STRUCTURE.proto",
"b.DOCUMENT_TEXT.proto",
"b.DOCUMENT_PAGES.proto",
"b.DOCUMENT_POSITION.proto",
"c.STRUCTURE.proto.gz",
"c.TEXT.proto.gz",
"c.PAGES.proto.gz",
"c.POSITION.proto.gz",
]
@pytest.mark.xfail(
reason="FIXME: The test is not stable, but has to work before we can deploy the code! Right now, we don't have parity between the proto and the json data."
)
# As DOCUMENT_POSITION is a very large file, the test takes forever. If you want to test it, add "DOCUMENT_POSITION" to the list below. - Added per default
@pytest.mark.parametrize("document_data", ["DOCUMENT_STRUCTURE", "DOCUMENT_TEXT", "DOCUMENT_PAGES", "DOCUMENT_POSITION"], indirect=True)
def test_proto_data_loader_end2end(document_data, proto_data_loader):
file_path, data, target = document_data
data = gzip.decompress(data)
loaded_data = proto_data_loader(file_path, data)
loaded_data_str = json.dumps(loaded_data, sort_keys=True)
target_str = json.dumps(target, sort_keys=True)
# If you want to look at the files in more detail uncomment code below
# global enum
# with open(f"input-{enum}.json", "w") as f:
# json.dump(target, f, sort_keys=True, indent=4)
# with open(f"output-{enum}.json", "w") as f:
# json.dump(loaded_data, f, sort_keys=True, indent=4)
# enum += 1
diff = DeepDiff(loaded_data_str, target_str, ignore_order=True)
# FIXME: remove this block when the test is stable
# if diff:
# with open(f"diff_test.json", "w") as f:
# f.write(diff.to_json(indent=4))
assert not diff
def test_proto_data_loader_unknown_document_type(proto_data_loader):
assert not proto_data_loader("unknown_document_type.proto", b"")
def test_proto_data_loader_file_name_matching(proto_data_loader, should_match):
for file_name in should_match:
assert proto_data_loader._match(file_name) is not None
@pytest.mark.parametrize("document_data", ["DOCUMENT_PAGES"], indirect=True)
def test_document_page_types(document_data, proto_data_loader):
# types from document reader
# number: int
# height: int
# width: int
# rotation: int
file_path, data, _ = document_data
data = gzip.decompress(data)
loaded_data = proto_data_loader(file_path, data)
assert isinstance(loaded_data, list)
assert all(isinstance(entry, dict) for entry in loaded_data)
# since all values need to be int anyway we can summarize it
assert all(all(isinstance(value, int) for value in entry.values()) for entry in loaded_data)
@pytest.mark.parametrize("document_data", ["DOCUMENT_POSITION"], indirect=True)
def test_document_position_data_types(document_data, proto_data_loader):
# types from document reader
# id: int
# stringIdxToPositionIdx: list[int]
# positions: list[list[float]]
file_path, data, _ = document_data
data = gzip.decompress(data)
loaded_data = proto_data_loader(file_path, data)
assert isinstance(loaded_data, list)
assert all(isinstance(entry, dict) for entry in loaded_data)
for entry in loaded_data:
assert isinstance(entry["id"], int)
assert isinstance(entry["stringIdxToPositionIdx"], list)
assert isinstance(entry["positions"], list)
assert all(isinstance(position, list) for position in entry["positions"])
assert all(all(isinstance(coordinate, float) for coordinate in position) for position in entry["positions"])
@pytest.mark.parametrize("document_data", ["DOCUMENT_STRUCTURE"], indirect=True)
def test_document_structure_types(document_data, proto_data_loader):
# types from document reader for DocumentStructure
# root: dict
# types from document reader for EntryData
# type: str
# tree_id: list[int]
# atomic_block_ids: list[int]
# page_numbers: list[int]
# properties: dict[str, str]
# children: list[dict]
file_path, data, _ = document_data
data = gzip.decompress(data)
loaded_data = proto_data_loader(file_path, data)
assert isinstance(loaded_data, dict)
assert isinstance(loaded_data["root"], dict)
assert isinstance(loaded_data["root"]["type"], str)
assert isinstance(loaded_data["root"]["treeId"], list)
assert isinstance(loaded_data["root"]["atomicBlockIds"], list)
assert isinstance(loaded_data["root"]["pageNumbers"], list)
assert isinstance(loaded_data["root"]["children"], list)
assert all(isinstance(value, int) for value in loaded_data["root"]["treeId"])
assert all(isinstance(value, int) for value in loaded_data["root"]["atomicBlockIds"])
assert all(isinstance(value, int) for value in loaded_data["root"]["pageNumbers"])
assert all(isinstance(value, dict) for value in loaded_data["root"]["properties"].values())
assert all(
all(isinstance(value, dict) for value in entry.values()) for entry in loaded_data["root"]["properties"].values()
)
assert all(isinstance(value, dict) for value in loaded_data["root"]["children"])
@pytest.mark.parametrize("document_data", ["DOCUMENT_TEXT"], indirect=True)
def test_document_text_data_types(document_data, proto_data_loader):
# types from document reader
# id: int
# page: int
# search_text: str
# number_on_page: int
# start: int
# end: int
# lineBreaks: list[int]
file_path, data, _ = document_data
data = gzip.decompress(data)
loaded_data = proto_data_loader(file_path, data)
assert isinstance(loaded_data, list)
assert all(isinstance(entry, dict) for entry in loaded_data)
for entry in loaded_data:
assert isinstance(entry["id"], int)
assert isinstance(entry["page"], int)
assert isinstance(entry["searchText"], str)
assert isinstance(entry["numberOnPage"], int)
assert isinstance(entry["start"], int)
assert isinstance(entry["end"], int)
assert all(isinstance(value, int) for value in entry["lineBreaks"])