feat: add tests for types of documentreader
This commit is contained in:
parent
67c30a5620
commit
43881de526
@ -64,6 +64,8 @@ class ProtoDataLoader:
|
|||||||
message.ParseFromString(data)
|
message.ParseFromString(data)
|
||||||
message_dict = MessageToDict(message, including_default_value_fields=True)
|
message_dict = MessageToDict(message, including_default_value_fields=True)
|
||||||
message_dict = convert_int64_fields(message_dict)
|
message_dict = convert_int64_fields(message_dict)
|
||||||
|
if document_type == "POSITION":
|
||||||
|
message_dict = transform_positions_to_list(message_dict)
|
||||||
|
|
||||||
return self._unpack(message_dict)
|
return self._unpack(message_dict)
|
||||||
|
|
||||||
@ -93,3 +95,28 @@ def convert_int64_fields(obj):
|
|||||||
elif isinstance(obj, str) and obj.isdigit():
|
elif isinstance(obj, str) and obj.isdigit():
|
||||||
return int(obj)
|
return int(obj)
|
||||||
return obj
|
return obj
|
||||||
|
|
||||||
|
|
||||||
|
def transform_positions_to_list(obj: dict | list) -> dict:
|
||||||
|
"""Transforms the repeated fields 'positions' to a lists of lists of floats
|
||||||
|
as expected by DocumentReader.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
obj (dict | list): Proto message dict
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict: Proto message dict
|
||||||
|
"""
|
||||||
|
if isinstance(obj, dict):
|
||||||
|
# Check if 'positions' is in the dictionary and reshape it as list of lists of floats
|
||||||
|
if "positions" in obj and isinstance(obj["positions"], list):
|
||||||
|
obj["positions"] = [pos["value"] for pos in obj["positions"] if isinstance(pos, dict) and "value" in pos]
|
||||||
|
|
||||||
|
# Recursively apply to all nested dictionaries
|
||||||
|
for key, value in obj.items():
|
||||||
|
obj[key] = transform_positions_to_list(value)
|
||||||
|
elif isinstance(obj, list):
|
||||||
|
# Recursively apply to all items in the list
|
||||||
|
obj = [transform_positions_to_list(item) for item in obj]
|
||||||
|
|
||||||
|
return obj
|
||||||
|
|||||||
@ -1,7 +1,10 @@
|
|||||||
import gzip
|
import gzip
|
||||||
import json
|
import json
|
||||||
|
import difflib
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
|
from google.protobuf import json_format
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from deepdiff import DeepDiff
|
from deepdiff import DeepDiff
|
||||||
|
|
||||||
@ -17,6 +20,8 @@ def test_data_dir():
|
|||||||
def document_data(request, test_data_dir) -> (str, bytes, dict | list):
|
def document_data(request, test_data_dir) -> (str, bytes, dict | list):
|
||||||
doc_type = request.param
|
doc_type = request.param
|
||||||
input_file_path = test_data_dir / f"72ea04dfdbeb277f37b9eb127efb0896.{doc_type}.proto.gz"
|
input_file_path = test_data_dir / f"72ea04dfdbeb277f37b9eb127efb0896.{doc_type}.proto.gz"
|
||||||
|
# input_file_path = test_data_dir / f"6ff38b030fa131e8e39bf5598513f981.{doc_type}.proto.gz" # new proto schema
|
||||||
|
# input_file_path = test_data_dir / f"8d1e6798a2c5dc14869e5b3ad8ae501f.{doc_type}.proto.gz"
|
||||||
target_file_path = test_data_dir / f"3f9d3d9f255007de8eff13648321e197.{doc_type}.json.gz"
|
target_file_path = test_data_dir / f"3f9d3d9f255007de8eff13648321e197.{doc_type}.json.gz"
|
||||||
|
|
||||||
input_data = input_file_path.read_bytes()
|
input_data = input_file_path.read_bytes()
|
||||||
@ -58,14 +63,25 @@ def test_proto_data_loader_end2end(document_data, proto_data_loader):
|
|||||||
data = gzip.decompress(data)
|
data = gzip.decompress(data)
|
||||||
loaded_data = proto_data_loader(file_path, data)
|
loaded_data = proto_data_loader(file_path, data)
|
||||||
|
|
||||||
|
# proto_json = json_format.MessageToJson(loaded_data)
|
||||||
|
|
||||||
loaded_data_str = json.dumps(loaded_data, sort_keys=True)
|
loaded_data_str = json.dumps(loaded_data, sort_keys=True)
|
||||||
target_str = json.dumps(target, sort_keys=True)
|
target_str = json.dumps(target, sort_keys=True)
|
||||||
|
|
||||||
|
# diff = difflib.unified_diff(loaded_data_str.splitlines(), target_str.splitlines())
|
||||||
|
|
||||||
|
# for line in diff:
|
||||||
|
# print(line)
|
||||||
|
|
||||||
|
# diff = DeepDiff(loaded_data, target, ignore_order=True)
|
||||||
|
# print(diff)
|
||||||
|
|
||||||
diff = DeepDiff(sorted(loaded_data_str), sorted(target_str), ignore_order=True)
|
diff = DeepDiff(sorted(loaded_data_str), sorted(target_str), ignore_order=True)
|
||||||
|
|
||||||
# FIXME: remove this block when the test is stable
|
# FIXME: remove this block when the test is stable
|
||||||
# if diff:
|
# if diff:
|
||||||
# with open("/tmp/diff.json", "w") as f:
|
# print(diff.to_json(indent=2))
|
||||||
|
# with open(f"diff_{file_path}.json", "w") as f:
|
||||||
# f.write(diff.to_json(indent=2))
|
# f.write(diff.to_json(indent=2))
|
||||||
|
|
||||||
assert not diff
|
assert not diff
|
||||||
@ -78,3 +94,107 @@ def test_proto_data_loader_unknown_document_type(proto_data_loader):
|
|||||||
def test_proto_data_loader_file_name_matching(proto_data_loader, should_match):
    """Every file name in ``should_match`` must be accepted by the loader's matcher."""
    for candidate in should_match:
        match = proto_data_loader._match(candidate)
        assert match is not None
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("document_data", ["DOCUMENT_PAGES"], indirect=True)
def test_document_page_types(document_data, proto_data_loader):
    """Loaded DOCUMENT_PAGES data must be a list of dicts that hold only ints.

    Fields DocumentReader expects per page, all of type ``int``:
    number, height, width, rotation.
    """
    file_path, compressed, _ = document_data
    pages = proto_data_loader(file_path, gzip.decompress(compressed))

    assert isinstance(pages, list)
    for page in pages:
        assert isinstance(page, dict)

    # Every field is expected to be an int, so one blanket check covers them all.
    for page in pages:
        for field_value in page.values():
            assert isinstance(field_value, int)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("document_data", ["DOCUMENT_POSITION"], indirect=True)
def test_document_position_data_types(document_data, proto_data_loader):
    """Loaded DOCUMENT_POSITION data must carry the types DocumentReader expects.

    Expected per entry:
        id: int
        stringIdxToPositionIdx: list[int]
        positions: list[list[float]]
    """
    file_path, compressed, _ = document_data
    records = proto_data_loader(file_path, gzip.decompress(compressed))

    assert isinstance(records, list)
    for record in records:
        assert isinstance(record, dict)

    for record in records:
        assert isinstance(record["id"], int)
        assert isinstance(record["stringIdxToPositionIdx"], list)

        positions = record["positions"]
        assert isinstance(positions, list)
        # First make sure every position is a list, then inspect the coordinates.
        for position in positions:
            assert isinstance(position, list)
        for position in positions:
            for coordinate in position:
                assert isinstance(coordinate, float)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("document_data", ["DOCUMENT_STRUCTURE"], indirect=True)
def test_document_structure_types(document_data, proto_data_loader):
    """The root entry of loaded DOCUMENT_STRUCTURE data must have the expected types.

    Types from DocumentReader for DocumentStructure:
        root: dict

    Types from DocumentReader for EntryData:
        type: str
        tree_id: list[int]
        atomic_block_ids: list[int]
        page_numbers: list[int]
        properties: dict[str, str]
        children: list[dict]

    Only the root entry is inspected; children are not checked recursively.
    """
    file_path, compressed, _ = document_data
    structure = proto_data_loader(file_path, gzip.decompress(compressed))

    assert isinstance(structure, dict)
    root = structure["root"]
    assert isinstance(root, dict)
    assert isinstance(root["type"], str)
    for list_field in ("treeId", "atomicBlockIds", "pageNumbers", "children"):
        assert isinstance(root[list_field], list)

    for int_list_field in ("treeId", "atomicBlockIds", "pageNumbers"):
        for item in root[int_list_field]:
            assert isinstance(item, int)

    # NOTE(review): the expectation listed above is dict[str, str], but the
    # loaded properties are asserted to be dicts of dicts -- confirm which one
    # reflects the intended contract.
    property_entries = list(root["properties"].values())
    for property_entry in property_entries:
        assert isinstance(property_entry, dict)
    for property_entry in property_entries:
        for nested_value in property_entry.values():
            assert isinstance(nested_value, dict)

    for child in root["children"]:
        assert isinstance(child, dict)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("document_data", ["DOCUMENT_TEXT"], indirect=True)
def test_document_text_data_types(document_data, proto_data_loader):
    """Loaded DOCUMENT_TEXT data must carry the types DocumentReader expects.

    Expected per entry:
        id: int
        page: int
        search_text: str
        number_on_page: int
        start: int
        end: int
        lineBreaks: list[int]
    """
    # Scalar fields and their expected types, in the order they are checked.
    expected_scalar_types = (
        ("id", int),
        ("page", int),
        ("searchText", str),
        ("numberOnPage", int),
        ("start", int),
        ("end", int),
    )

    file_path, compressed, _ = document_data
    text_blocks = proto_data_loader(file_path, gzip.decompress(compressed))

    assert isinstance(text_blocks, list)
    for text_block in text_blocks:
        assert isinstance(text_block, dict)

    for text_block in text_blocks:
        for field_name, expected_type in expected_scalar_types:
            assert isinstance(text_block[field_name], expected_type)
        for line_break in text_block["lineBreaks"]:
            assert isinstance(line_break, int)
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user