From 43881de5264b0f594f281dfd74977fc045f748c5 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Jonathan=20K=C3=B6ssler?=
Date: Fri, 20 Sep 2024 16:42:55 +0200
Subject: [PATCH] feat: add tests for types of DocumentReader

---
 pyinfra/storage/proto_data_loader.py      |  27 +++++
 tests/unit_test/proto_data_loader_test.py | 122 +++++++++++++++++++++-
 2 files changed, 148 insertions(+), 1 deletion(-)

diff --git a/pyinfra/storage/proto_data_loader.py b/pyinfra/storage/proto_data_loader.py
index 32f2978..1bb6d9c 100644
--- a/pyinfra/storage/proto_data_loader.py
+++ b/pyinfra/storage/proto_data_loader.py
@@ -64,6 +64,8 @@ class ProtoDataLoader:
         message.ParseFromString(data)
         message_dict = MessageToDict(message, including_default_value_fields=True)
         message_dict = convert_int64_fields(message_dict)
+        if document_type == "POSITION":
+            message_dict = transform_positions_to_list(message_dict)
 
         return self._unpack(message_dict)
 
@@ -93,3 +95,28 @@ def convert_int64_fields(obj):
     elif isinstance(obj, str) and obj.isdigit():
         return int(obj)
     return obj
+
+
+def transform_positions_to_list(obj: dict | list) -> dict | list:
+    """Transforms the repeated field 'positions' to a list of lists of floats,
+    as expected by DocumentReader.
+
+    Args:
+        obj (dict | list): Proto message dict
+
+    Returns:
+        dict | list: Proto message dict with reshaped 'positions' fields
+    """
+    if isinstance(obj, dict):
+        # Check if 'positions' is in the dictionary and reshape it as a list of lists of floats
+        if "positions" in obj and isinstance(obj["positions"], list):
+            obj["positions"] = [pos["value"] for pos in obj["positions"] if isinstance(pos, dict) and "value" in pos]
+
+        # Recursively apply to all nested dictionaries
+        for key, value in obj.items():
+            obj[key] = transform_positions_to_list(value)
+    elif isinstance(obj, list):
+        # Recursively apply to all items in the list
+        obj = [transform_positions_to_list(item) for item in obj]
+
+    return obj
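
For illustration, a minimal sketch of the reshaping this helper performs on a MessageToDict-style dict. Only the "positions" and "value" keys come from the code above; the surrounding "id" key is made up for the example:

    from pyinfra.storage.proto_data_loader import transform_positions_to_list

    before = {"id": "7", "positions": [{"value": [1.0, 2.0, 3.0, 4.0]},
                                       {"value": [5.0, 6.0, 7.0, 8.0]}]}
    after = transform_positions_to_list(before)
    # after == {"id": "7", "positions": [[1.0, 2.0, 3.0, 4.0],
    #                                    [5.0, 6.0, 7.0, 8.0]]}
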
diff --git a/tests/unit_test/proto_data_loader_test.py b/tests/unit_test/proto_data_loader_test.py
index e8dc9c1..0b12951 100644
--- a/tests/unit_test/proto_data_loader_test.py
+++ b/tests/unit_test/proto_data_loader_test.py
@@ -1,7 +1,10 @@
 import gzip
 import json
+import difflib
 from pathlib import Path
 
+from google.protobuf import json_format
+
 import pytest
 from deepdiff import DeepDiff
 
@@ -17,6 +20,8 @@ def test_data_dir():
 def document_data(request, test_data_dir) -> (str, bytes, dict | list):
     doc_type = request.param
     input_file_path = test_data_dir / f"72ea04dfdbeb277f37b9eb127efb0896.{doc_type}.proto.gz"
+    # input_file_path = test_data_dir / f"6ff38b030fa131e8e39bf5598513f981.{doc_type}.proto.gz"  # new proto schema
+    # input_file_path = test_data_dir / f"8d1e6798a2c5dc14869e5b3ad8ae501f.{doc_type}.proto.gz"
     target_file_path = test_data_dir / f"3f9d3d9f255007de8eff13648321e197.{doc_type}.json.gz"
 
     input_data = input_file_path.read_bytes()
@@ -58,14 +63,25 @@ def test_proto_data_loader_end2end(document_data, proto_data_loader):
     data = gzip.decompress(data)
     loaded_data = proto_data_loader(file_path, data)
 
+    # proto_json = json_format.MessageToJson(loaded_data)
+
     loaded_data_str = json.dumps(loaded_data, sort_keys=True)
     target_str = json.dumps(target, sort_keys=True)
+    # diff = difflib.unified_diff(loaded_data_str.splitlines(), target_str.splitlines())
+
+    # for line in diff:
+    #     print(line)
+
+    # diff = DeepDiff(loaded_data, target, ignore_order=True)
+    # print(diff)
+
     diff = DeepDiff(sorted(loaded_data_str), sorted(target_str), ignore_order=True)
 
     # FIXME: remove this block when the test is stable
     # if diff:
-    #     with open("/tmp/diff.json", "w") as f:
+    #     print(diff.to_json(indent=2))
+    #     with open(f"diff_{file_path}.json", "w") as f:
     #         f.write(diff.to_json(indent=2))
 
     assert not diff
 
@@ -78,3 +94,107 @@ def test_proto_data_loader_unknown_document_type(proto_data_loader):
 def test_proto_data_loader_file_name_matching(proto_data_loader, should_match):
     for file_name in should_match:
         assert proto_data_loader._match(file_name) is not None
+
+
+@pytest.mark.parametrize("document_data", ["DOCUMENT_PAGES"], indirect=True)
+def test_document_page_types(document_data, proto_data_loader):
+    # types from document reader
+    # number: int
+    # height: int
+    # width: int
+    # rotation: int
+
+    file_path, data, _ = document_data
+    data = gzip.decompress(data)
+    loaded_data = proto_data_loader(file_path, data)
+
+    assert isinstance(loaded_data, list)
+    assert all(isinstance(entry, dict) for entry in loaded_data)
+
+    # all values are expected to be int, so they can be checked in a single pass
+    assert all(all(isinstance(value, int) for value in entry.values()) for entry in loaded_data)
+
+
+@pytest.mark.parametrize("document_data", ["DOCUMENT_POSITION"], indirect=True)
+def test_document_position_data_types(document_data, proto_data_loader):
+    # types from document reader
+    # id: int
+    # stringIdxToPositionIdx: list[int]
+    # positions: list[list[float]]
+
+    file_path, data, _ = document_data
+    data = gzip.decompress(data)
+    loaded_data = proto_data_loader(file_path, data)
+
+    assert isinstance(loaded_data, list)
+    assert all(isinstance(entry, dict) for entry in loaded_data)
+
+    for entry in loaded_data:
+        assert isinstance(entry["id"], int)
+        assert isinstance(entry["stringIdxToPositionIdx"], list)
+        assert isinstance(entry["positions"], list)
+        assert all(isinstance(position, list) for position in entry["positions"])
+        assert all(all(isinstance(coordinate, float) for coordinate in position) for position in entry["positions"])
+
+
+@pytest.mark.parametrize("document_data", ["DOCUMENT_STRUCTURE"], indirect=True)
+def test_document_structure_types(document_data, proto_data_loader):
+    # types from document reader for DocumentStructure
+    # root: dict
+
+    # types from document reader for EntryData
+    # type: str
+    # tree_id: list[int]
+    # atomic_block_ids: list[int]
+    # page_numbers: list[int]
+    # properties: dict[str, str]
+    # children: list[dict]
+
+    file_path, data, _ = document_data
+    data = gzip.decompress(data)
+    loaded_data = proto_data_loader(file_path, data)
+
+    assert isinstance(loaded_data, dict)
+    assert isinstance(loaded_data["root"], dict)
+    assert isinstance(loaded_data["root"]["type"], str)
+    assert isinstance(loaded_data["root"]["treeId"], list)
+    assert isinstance(loaded_data["root"]["atomicBlockIds"], list)
+    assert isinstance(loaded_data["root"]["pageNumbers"], list)
+    assert isinstance(loaded_data["root"]["children"], list)
+
+    assert all(isinstance(value, int) for value in loaded_data["root"]["treeId"])
+    assert all(isinstance(value, int) for value in loaded_data["root"]["atomicBlockIds"])
+    assert all(isinstance(value, int) for value in loaded_data["root"]["pageNumbers"])
+    assert all(isinstance(value, dict) for value in loaded_data["root"]["properties"].values())
+    assert all(
+        all(isinstance(value, dict) for value in entry.values()) for entry in loaded_data["root"]["properties"].values()
+    )
+    assert all(isinstance(value, dict) for value in loaded_data["root"]["children"])
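
Note that the inline type comments use the proto field names (tree_id, atomic_block_ids, page_numbers) while the assertions check lowerCamelCase keys (treeId, atomicBlockIds, pageNumbers): MessageToDict, as called in proto_data_loader.py with its default preserving_proto_field_name=False, emits lowerCamelCase keys. A minimal illustration; some_message and its field are only an example:

    from google.protobuf.json_format import MessageToDict

    # some_message has a repeated int64 field `tree_id`
    MessageToDict(some_message)                                    # {"treeId": [...]}
    MessageToDict(some_message, preserving_proto_field_name=True)  # {"tree_id": [...]}

Similarly, MessageToDict renders int64 values as strings, which is what convert_int64_fields undoes before the type assertions run.
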
+
+
+@pytest.mark.parametrize("document_data", ["DOCUMENT_TEXT"], indirect=True)
+def test_document_text_data_types(document_data, proto_data_loader):
+    # types from document reader
+    # id: int
+    # page: int
+    # search_text: str
+    # number_on_page: int
+    # start: int
+    # end: int
+    # lineBreaks: list[int]
+
+    file_path, data, _ = document_data
+    data = gzip.decompress(data)
+    loaded_data = proto_data_loader(file_path, data)
+
+    assert isinstance(loaded_data, list)
+    assert all(isinstance(entry, dict) for entry in loaded_data)
+
+    for entry in loaded_data:
+        assert isinstance(entry["id"], int)
+        assert isinstance(entry["page"], int)
+        assert isinstance(entry["searchText"], str)
+        assert isinstance(entry["numberOnPage"], int)
+        assert isinstance(entry["start"], int)
+        assert isinstance(entry["end"], int)
+        assert all(isinstance(value, int) for value in entry["lineBreaks"])
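
A possible follow-up, not part of this patch: the per-field isinstance checks repeat across the four new tests and could be driven by a small field-to-type schema. The helper name and schema dict below are illustrative only:

    def assert_entry_types(entry: dict, schema: dict[str, type]) -> None:
        """Assert that each listed field of `entry` has the expected type."""
        for field, expected_type in schema.items():
            assert isinstance(entry[field], expected_type), f"{field!r} should be {expected_type.__name__}"

    # usage, e.g. for DOCUMENT_TEXT entries:
    # assert_entry_types(entry, {"id": int, "page": int, "searchText": str,
    #                            "numberOnPage": int, "start": int, "end": int})
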