feat: implement proto data loader

This commit is contained in:
Julius Unverfehrt 2024-07-16 16:32:58 +02:00
parent edba6fc4da
commit 9d55b3be89
2 changed files with 133 additions and 0 deletions

View File

@ -0,0 +1,70 @@
import re
from enum import Enum
from pathlib import Path
from google.protobuf.json_format import MessageToDict
from kn_utils.logging import logger
from pyinfra.proto import DocumentStructure_pb2, DocumentTextData_pb2, DocumentPage_pb2, DocumentPositionData_pb2
class ProtoDataLoader:
class DocumentType(Enum):
STRUCTURE = "STRUCTURE"
TEXT = "TEXT"
PAGES = "PAGES"
POSITION = "POSITION"
KEYS_TO_UNPACK = ["documentTextData", "documentPages", "documentPositionData"]
def __init__(self):
self.pattern = self._build_pattern()
def __call__(self, file_name: str | Path, data: bytes) -> dict:
return self._load(file_name, data)
def _load(self, file_name: str | Path, data: bytes) -> dict | list:
file_name = str(file_name)
document_type = self._match(file_name)
logger.info(f"Loading document type: {document_type}")
match document_type:
case self.DocumentType.STRUCTURE.value:
schema = DocumentStructure_pb2
message = schema.DocumentStructure()
case self.DocumentType.TEXT.value:
schema = DocumentTextData_pb2
message = schema.AllDocumentTextData()
case self.DocumentType.PAGES.value:
schema = DocumentPage_pb2
message = schema.AllDocumentPages()
case self.DocumentType.POSITION.value:
schema = DocumentPositionData_pb2
message = schema.AllDocumentPositionData()
case _:
raise ValueError(f"Document type {document_type} not supported")
message.ParseFromString(data)
message_dict = MessageToDict(message)
unpacked_message = self._unpack(message_dict)
return unpacked_message
def _build_pattern(self) -> str:
types = "|".join([doc_type.value for doc_type in self.DocumentType])
pattern = r"\..*(" + types + r").*\.proto\..*"
return pattern
def _match(self, file_name: str) -> str:
return re.search(self.pattern, file_name).group(1)
def _unpack(self, message_dict: dict) -> list | dict:
if len(message_dict) > 1:
return message_dict
for key in message_dict:
if key in self.KEYS_TO_UNPACK:
logger.info(f"Unpacking key: {key}")
return message_dict[key]
return message_dict

View File

@ -0,0 +1,63 @@
import gzip
from pathlib import Path
import pytest
from pyinfra.storage.proto_data_loader import ProtoDataLoader
@pytest.fixture
def test_data_dir():
return Path(__file__).parents[1] / "data"
@pytest.fixture
def document_structure_document(test_data_dir) -> (str, bytes):
file_path = test_data_dir / "72ea04dfdbeb277f37b9eb127efb0896.DOCUMENT_STRUCTURE.proto.gz"
data = file_path.read_bytes()
return file_path, data
@pytest.fixture
def document_text_document(test_data_dir) -> (str, bytes):
file_path = test_data_dir / "72ea04dfdbeb277f37b9eb127efb0896.DOCUMENT_TEXT.proto.gz"
data = file_path.read_bytes()
return file_path, data
@pytest.fixture
def document_pages_document(test_data_dir) -> (str, bytes):
file_path = test_data_dir / "72ea04dfdbeb277f37b9eb127efb0896.DOCUMENT_PAGES.proto.gz"
data = file_path.read_bytes()
return file_path, data
@pytest.fixture
def document_position_document(test_data_dir) -> (str, bytes):
file_path = test_data_dir / "72ea04dfdbeb277f37b9eb127efb0896.DOCUMENT_POSITION.proto.gz"
data = file_path.read_bytes()
return file_path, data
@pytest.fixture
def proto_data_loader():
return ProtoDataLoader()
@pytest.mark.parametrize(
"document_fixture",
[
"document_structure_document",
"document_text_document",
"document_pages_document",
"document_position_document",
],
)
def test_proto_data_loader(document_fixture, request, proto_data_loader):
file_path, data = request.getfixturevalue(document_fixture)
data = gzip.decompress(data)
loaded_data = proto_data_loader(file_path, data)
# TODO: Right now, we don't have access to proto-json pairs to compare the loaded data with the expected data.
# If this becomes available, please update this test to compare the loaded data with the expected data.
assert isinstance(loaded_data, dict) or isinstance(loaded_data, list)