feat: implement proto data loader
This commit is contained in:
parent
edba6fc4da
commit
9d55b3be89
70
pyinfra/storage/proto_data_loader.py
Normal file
70
pyinfra/storage/proto_data_loader.py
Normal file
@ -0,0 +1,70 @@
|
||||
import re
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
|
||||
from google.protobuf.json_format import MessageToDict
|
||||
from kn_utils.logging import logger
|
||||
|
||||
from pyinfra.proto import DocumentStructure_pb2, DocumentTextData_pb2, DocumentPage_pb2, DocumentPositionData_pb2
|
||||
|
||||
|
||||
class ProtoDataLoader:
|
||||
class DocumentType(Enum):
|
||||
STRUCTURE = "STRUCTURE"
|
||||
TEXT = "TEXT"
|
||||
PAGES = "PAGES"
|
||||
POSITION = "POSITION"
|
||||
|
||||
KEYS_TO_UNPACK = ["documentTextData", "documentPages", "documentPositionData"]
|
||||
|
||||
def __init__(self):
|
||||
self.pattern = self._build_pattern()
|
||||
|
||||
def __call__(self, file_name: str | Path, data: bytes) -> dict:
|
||||
return self._load(file_name, data)
|
||||
|
||||
def _load(self, file_name: str | Path, data: bytes) -> dict | list:
|
||||
file_name = str(file_name)
|
||||
document_type = self._match(file_name)
|
||||
|
||||
logger.info(f"Loading document type: {document_type}")
|
||||
match document_type:
|
||||
case self.DocumentType.STRUCTURE.value:
|
||||
schema = DocumentStructure_pb2
|
||||
message = schema.DocumentStructure()
|
||||
case self.DocumentType.TEXT.value:
|
||||
schema = DocumentTextData_pb2
|
||||
message = schema.AllDocumentTextData()
|
||||
case self.DocumentType.PAGES.value:
|
||||
schema = DocumentPage_pb2
|
||||
message = schema.AllDocumentPages()
|
||||
case self.DocumentType.POSITION.value:
|
||||
schema = DocumentPositionData_pb2
|
||||
message = schema.AllDocumentPositionData()
|
||||
case _:
|
||||
raise ValueError(f"Document type {document_type} not supported")
|
||||
|
||||
message.ParseFromString(data)
|
||||
message_dict = MessageToDict(message)
|
||||
unpacked_message = self._unpack(message_dict)
|
||||
|
||||
return unpacked_message
|
||||
|
||||
def _build_pattern(self) -> str:
|
||||
types = "|".join([doc_type.value for doc_type in self.DocumentType])
|
||||
pattern = r"\..*(" + types + r").*\.proto\..*"
|
||||
return pattern
|
||||
|
||||
def _match(self, file_name: str) -> str:
|
||||
return re.search(self.pattern, file_name).group(1)
|
||||
|
||||
def _unpack(self, message_dict: dict) -> list | dict:
|
||||
if len(message_dict) > 1:
|
||||
return message_dict
|
||||
|
||||
for key in message_dict:
|
||||
if key in self.KEYS_TO_UNPACK:
|
||||
logger.info(f"Unpacking key: {key}")
|
||||
return message_dict[key]
|
||||
|
||||
return message_dict
|
||||
63
tests/unit_test/data_loader_test.py
Normal file
63
tests/unit_test/data_loader_test.py
Normal file
@ -0,0 +1,63 @@
|
||||
import gzip
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from pyinfra.storage.proto_data_loader import ProtoDataLoader
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def test_data_dir():
|
||||
return Path(__file__).parents[1] / "data"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def document_structure_document(test_data_dir) -> (str, bytes):
|
||||
file_path = test_data_dir / "72ea04dfdbeb277f37b9eb127efb0896.DOCUMENT_STRUCTURE.proto.gz"
|
||||
data = file_path.read_bytes()
|
||||
return file_path, data
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def document_text_document(test_data_dir) -> (str, bytes):
|
||||
file_path = test_data_dir / "72ea04dfdbeb277f37b9eb127efb0896.DOCUMENT_TEXT.proto.gz"
|
||||
data = file_path.read_bytes()
|
||||
return file_path, data
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def document_pages_document(test_data_dir) -> (str, bytes):
|
||||
file_path = test_data_dir / "72ea04dfdbeb277f37b9eb127efb0896.DOCUMENT_PAGES.proto.gz"
|
||||
data = file_path.read_bytes()
|
||||
return file_path, data
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def document_position_document(test_data_dir) -> (str, bytes):
|
||||
file_path = test_data_dir / "72ea04dfdbeb277f37b9eb127efb0896.DOCUMENT_POSITION.proto.gz"
|
||||
data = file_path.read_bytes()
|
||||
return file_path, data
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def proto_data_loader():
|
||||
return ProtoDataLoader()
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"document_fixture",
|
||||
[
|
||||
"document_structure_document",
|
||||
"document_text_document",
|
||||
"document_pages_document",
|
||||
"document_position_document",
|
||||
],
|
||||
)
|
||||
def test_proto_data_loader(document_fixture, request, proto_data_loader):
|
||||
file_path, data = request.getfixturevalue(document_fixture)
|
||||
data = gzip.decompress(data)
|
||||
loaded_data = proto_data_loader(file_path, data)
|
||||
|
||||
# TODO: Right now, we don't have access to proto-json pairs to compare the loaded data with the expected data.
|
||||
# If this becomes available, please update this test to compare the loaded data with the expected data.
|
||||
assert isinstance(loaded_data, dict) or isinstance(loaded_data, list)
|
||||
Loading…
x
Reference in New Issue
Block a user