diff --git a/pyinfra/storage/proto_data_loader.py b/pyinfra/storage/proto_data_loader.py
new file mode 100644
index 0000000..e6ca643
--- /dev/null
+++ b/pyinfra/storage/proto_data_loader.py
@@ -0,0 +1,103 @@
+import re
+from enum import Enum
+from pathlib import Path
+
+from google.protobuf.json_format import MessageToDict
+from kn_utils.logging import logger
+
+from pyinfra.proto import DocumentStructure_pb2, DocumentTextData_pb2, DocumentPage_pb2, DocumentPositionData_pb2
+
+
+class ProtoDataLoader:
+    """Deserialize protobuf-encoded document files into dicts or lists.
+
+    The protobuf schema is chosen from the file name, which has to contain one of
+    the ``DocumentType`` values followed by ``.proto.<suffix>``, for example
+    ``<hash>.DOCUMENT_TEXT.proto.gz``.
+    """
+
+    class DocumentType(Enum):
+        STRUCTURE = "STRUCTURE"
+        TEXT = "TEXT"
+        PAGES = "PAGES"
+        POSITION = "POSITION"
+
+    # Wrapper keys whose value is returned directly when it is the only top-level key.
+    KEYS_TO_UNPACK = ["documentTextData", "documentPages", "documentPositionData"]
+
+    def __init__(self):
+        self.pattern = self._build_pattern()
+
+    def __call__(self, file_name: str | Path, data: bytes) -> dict | list:
+        """Load ``data`` using the schema implied by ``file_name``."""
+        return self._load(file_name, data)
+
+    def _load(self, file_name: str | Path, data: bytes) -> dict | list:
+        """Parse a serialized protobuf message and convert it to a dict or list.
+
+        Args:
+            file_name: Name or path of the source file; only used to determine
+                the document type.
+            data: Serialized protobuf message, passed as-is to ``ParseFromString``.
+
+        Returns:
+            The ``MessageToDict`` result, unwrapped by ``_unpack``.
+
+        Raises:
+            ValueError: If ``file_name`` does not name a supported document type.
+        """
+        file_name = str(file_name)
+        document_type = self._match(file_name)
+
+        logger.info(f"Loading document type: {document_type}")
+        match document_type:
+            case self.DocumentType.STRUCTURE.value:
+                schema = DocumentStructure_pb2
+                message = schema.DocumentStructure()
+            case self.DocumentType.TEXT.value:
+                schema = DocumentTextData_pb2
+                message = schema.AllDocumentTextData()
+            case self.DocumentType.PAGES.value:
+                schema = DocumentPage_pb2
+                message = schema.AllDocumentPages()
+            case self.DocumentType.POSITION.value:
+                schema = DocumentPositionData_pb2
+                message = schema.AllDocumentPositionData()
+            case _:
+                raise ValueError(f"Document type {document_type} not supported")
+
+        message.ParseFromString(data)
+        message_dict = MessageToDict(message)
+        unpacked_message = self._unpack(message_dict)
+
+        return unpacked_message
+
+    def _build_pattern(self) -> str:
+        """Build the regex whose first group captures the document type in a file name."""
+        types = "|".join([doc_type.value for doc_type in self.DocumentType])
+        pattern = r"\..*(" + types + r").*\.proto\..*"
+        return pattern
+
+    def _match(self, file_name: str) -> str:
+        """Return the document type found in ``file_name``.
+
+        Raises:
+            ValueError: If ``file_name`` does not match the expected naming scheme.
+        """
+        found = re.search(self.pattern, file_name)
+        if found is None:
+            # Without this guard, ``None.group`` would raise an opaque AttributeError.
+            raise ValueError(f"Could not determine document type from file name: {file_name}")
+        return found.group(1)
+
+    def _unpack(self, message_dict: dict) -> list | dict:
+        """Return the value of a lone wrapper key, otherwise the dict unchanged."""
+        if len(message_dict) > 1:
+            return message_dict
+
+        for key in message_dict:
+            if key in self.KEYS_TO_UNPACK:
+                logger.info(f"Unpacking key: {key}")
+                return message_dict[key]
+
+        return message_dict
diff --git a/tests/unit_test/data_loader_test.py b/tests/unit_test/data_loader_test.py
new file mode 100644
index 0000000..74248e4
--- /dev/null
+++ b/tests/unit_test/data_loader_test.py
@@ -0,0 +1,68 @@
+import gzip
+from pathlib import Path
+
+import pytest
+
+from pyinfra.storage.proto_data_loader import ProtoDataLoader
+
+
+@pytest.fixture
+def test_data_dir():
+    return Path(__file__).parents[1] / "data"
+
+
+@pytest.fixture
+def document_structure_document(test_data_dir) -> tuple[Path, bytes]:
+    file_path = test_data_dir / "72ea04dfdbeb277f37b9eb127efb0896.DOCUMENT_STRUCTURE.proto.gz"
+    data = file_path.read_bytes()
+    return file_path, data
+
+
+@pytest.fixture
+def document_text_document(test_data_dir) -> tuple[Path, bytes]:
+    file_path = test_data_dir / "72ea04dfdbeb277f37b9eb127efb0896.DOCUMENT_TEXT.proto.gz"
+    data = file_path.read_bytes()
+    return file_path, data
+
+
+@pytest.fixture
+def document_pages_document(test_data_dir) -> tuple[Path, bytes]:
+    file_path = test_data_dir / "72ea04dfdbeb277f37b9eb127efb0896.DOCUMENT_PAGES.proto.gz"
+    data = file_path.read_bytes()
+    return file_path, data
+
+
+@pytest.fixture
+def document_position_document(test_data_dir) -> tuple[Path, bytes]:
+    file_path = test_data_dir / "72ea04dfdbeb277f37b9eb127efb0896.DOCUMENT_POSITION.proto.gz"
+    data = file_path.read_bytes()
+    return file_path, data
+
+
+@pytest.fixture
+def proto_data_loader():
+    return ProtoDataLoader()
+
+
+@pytest.mark.parametrize(
+    "document_fixture",
+    [
+        "document_structure_document",
+        "document_text_document",
+        "document_pages_document",
+        "document_position_document",
+    ],
+)
+def test_proto_data_loader(document_fixture, request, proto_data_loader):
+    file_path, data = request.getfixturevalue(document_fixture)
+    data = gzip.decompress(data)
+    loaded_data = proto_data_loader(file_path, data)
+
+    # TODO: Right now, we don't have access to proto-json pairs to compare the loaded data with the expected data.
+    # If this becomes available, please update this test to compare the loaded data with the expected data.
+    assert isinstance(loaded_data, (dict, list))
+
+
+def test_proto_data_loader_rejects_unknown_file_name(proto_data_loader):
+    with pytest.raises(ValueError):
+        proto_data_loader("72ea04dfdbeb277f37b9eb127efb0896.UNKNOWN.proto.gz", b"")