feat: integrate proto data loader in pipeline

Julius Unverfehrt 2024-07-16 17:34:26 +02:00
parent 9d55b3be89
commit 0d232226fd
3 changed files with 78 additions and 36 deletions


@@ -9,16 +9,38 @@ from pyinfra.proto import DocumentStructure_pb2, DocumentTextData_pb2, DocumentPage_pb2, DocumentPositionData_pb2
class ProtoDataLoader:
    """Loads proto data from a file and returns it as a dictionary or list.

    The loader is a singleton and should be used as a callable; the file name and the byte data are passed
    as arguments. The document type is determined from the file name, and the data is returned as a
    dictionary or a list, depending on the document type.
    The DocumentType enum maps every supported document type to its proto schema and message name.
    KEYS_TO_UNPACK lists the keys that should be unpacked from the message dictionary. A key is unpacked
    only if it is the sole key in the message dictionary; this is necessary because top-level lists arrive
    wrapped in a single-key dictionary.
    """

    _instance = None
    _pattern = None

    class DocumentType(Enum):
        STRUCTURE = "STRUCTURE"
        TEXT = "TEXT"
        PAGES = "PAGES"
        POSITION = "POSITION"
        STRUCTURE = (DocumentStructure_pb2.DocumentStructure, "DocumentStructure")
        TEXT = (DocumentTextData_pb2.AllDocumentTextData, "AllDocumentTextData")
        PAGES = (DocumentPage_pb2.AllDocumentPages, "AllDocumentPages")
        POSITION = (DocumentPositionData_pb2.AllDocumentPositionData, "AllDocumentPositionData")

    KEYS_TO_UNPACK = ["documentTextData", "documentPages", "documentPositionData"]

    def __init__(self):
        self.pattern = self._build_pattern()

    @classmethod
    def _build_pattern(cls) -> re.Pattern:
        types = "|".join([dt.name for dt in cls.DocumentType])
        return re.compile(rf"\..*({types}).*\.proto.*")

    def __new__(cls, *args, **kwargs):
        if cls._instance is None:
            cls._instance = super().__new__(cls)
            cls._pattern = cls._build_pattern()
        return cls._instance

    def __call__(self, file_name: str | Path, data: bytes) -> dict:
        return self._load(file_name, data)
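
A minimal usage sketch of the new loader, assuming the class above is importable from its module; the file name is hypothetical, and an empty payload parses to an empty message:

from pyinfra.storage.proto_data_loader import ProtoDataLoader

loader = ProtoDataLoader()
assert loader is ProtoDataLoader()  # singleton: __new__ always returns the shared instance

# "doc.STRUCTURE.proto" matches the STRUCTURE pattern; b"" decodes to an empty DocumentStructure
result = loader("doc.STRUCTURE.proto", b"")
assert result == {}  # MessageToDict of an empty message is an empty dict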
@@ -27,44 +49,28 @@
        file_name = str(file_name)
        document_type = self._match(file_name)
        logger.info(f"Loading document type: {document_type}")
        match document_type:
            case self.DocumentType.STRUCTURE.value:
                schema = DocumentStructure_pb2
                message = schema.DocumentStructure()
            case self.DocumentType.TEXT.value:
                schema = DocumentTextData_pb2
                message = schema.AllDocumentTextData()
            case self.DocumentType.PAGES.value:
                schema = DocumentPage_pb2
                message = schema.AllDocumentPages()
            case self.DocumentType.POSITION.value:
                schema = DocumentPositionData_pb2
                message = schema.AllDocumentPositionData()
            case _:
                raise ValueError(f"Document type {document_type} not supported")
        if not document_type:
            raise ValueError(f"Unknown document type: {file_name}, supported types: {self.DocumentType}")
        logger.debug(f"Loading document type: {document_type}")
        schema, _ = self.DocumentType[document_type].value
        message = schema()
        message.ParseFromString(data)
        message_dict = MessageToDict(message)
        unpacked_message = self._unpack(message_dict)
        return unpacked_message
        return self._unpack(message_dict)

    def _build_pattern(self) -> str:
        types = "|".join([doc_type.value for doc_type in self.DocumentType])
        pattern = r"\..*(" + types + r").*\.proto\..*"
        return pattern

    def _match(self, file_name: str) -> str:
        return re.search(self.pattern, file_name).group(1)

    def _match(self, file_name: str) -> str | None:
        match = self._pattern.search(file_name)
        return match.group(1) if match else None

    def _unpack(self, message_dict: dict) -> list | dict:
        if len(message_dict) > 1:
            return message_dict
        for key in message_dict:
            if key in self.KEYS_TO_UNPACK:
                logger.info(f"Unpacking key: {key}")
        for key in self.KEYS_TO_UNPACK:
            if key in message_dict:
                logger.debug(f"Unpacking key: {key}")
                return message_dict[key]
        return message_dict
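
Two behaviours worth illustrating with plain values (made-up payloads, same loader as above): _match now returns None for unknown names instead of raising, and _unpack flattens only a single-key dictionary whose key is whitelisted.

loader = ProtoDataLoader()

# _match: group(1) of the compiled pattern, or None when nothing matches
assert loader._match("a.DOCUMENT_PAGES.proto.gz") == "PAGES"
assert loader._match("report.json") is None

# _unpack: flatten only when a whitelisted key is the sole key
assert loader._unpack({"documentPages": [{"p": 1}]}) == [{"p": 1}]
assert loader._unpack({"documentPages": [], "extra": 0}) == {"documentPages": [], "extra": 0}
assert loader._unpack({"misc": 1}) == {"misc": 1}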


@@ -6,6 +6,7 @@ from typing import Union
from kn_utils.logging import logger
from pydantic import BaseModel, ValidationError
from pyinfra.storage.proto_data_loader import ProtoDataLoader
from pyinfra.storage.storages.storage import Storage
@@ -80,7 +81,14 @@ def _download_single_file(file_path: str, storage: Storage) -> bytes:
    data = storage.get_object(file_path)
    data = gzip.decompress(data) if ".gz" in file_path else data
    data = json.loads(data.decode("utf-8")) if ".json" in file_path else data
    if ".json" in file_path:
        data = json.loads(data.decode("utf-8"))
    elif ".proto" in file_path:
        data = ProtoDataLoader()(file_path, data)
    else:
        pass  # identity for other file types
    logger.info(f"Downloaded {file_path} from storage.")
    return data
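
A gzipped proto blob now flows through both branches: gunzip first, then the proto loader. A condensed, storage-free sketch of the same dispatch (decode_blob is a hypothetical helper, not part of the codebase):

import gzip
import json

from pyinfra.storage.proto_data_loader import ProtoDataLoader

def decode_blob(file_path: str, raw: bytes):
    # mirrors _download_single_file's suffix dispatch, minus the Storage dependency
    data = gzip.decompress(raw) if ".gz" in file_path else raw
    if ".json" in file_path:
        return json.loads(data.decode("utf-8"))
    if ".proto" in file_path:
        return ProtoDataLoader()(file_path, data)
    return data  # identity for other file types

# e.g. a gzipped, empty TEXT document: gunzip, then proto decode -> {}
assert decode_blob("doc.TEXT.proto.gz", gzip.compress(b"")) == {}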


@@ -44,6 +44,24 @@ def proto_data_loader():
    return ProtoDataLoader()


@pytest.fixture
def should_match():
    return [
        "a.DOCUMENT_STRUCTURE.proto.gz",
        "a.DOCUMENT_TEXT.proto.gz",
        "a.DOCUMENT_PAGES.proto.gz",
        "a.DOCUMENT_POSITION.proto.gz",
        "b.DOCUMENT_STRUCTURE.proto",
        "b.DOCUMENT_TEXT.proto",
        "b.DOCUMENT_PAGES.proto",
        "b.DOCUMENT_POSITION.proto",
        "c.STRUCTURE.proto.gz",
        "c.TEXT.proto.gz",
        "c.PAGES.proto.gz",
        "c.POSITION.proto.gz",
    ]
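
Both naming schemes in this fixture match because the pattern only requires a type name somewhere between a dot and a ".proto" suffix, so "DOCUMENT_STRUCTURE" matches via its "STRUCTURE" substring. Reproducing the compiled pattern by hand:

import re

pattern = re.compile(r"\..*(STRUCTURE|TEXT|PAGES|POSITION).*\.proto.*")

assert pattern.search("a.DOCUMENT_STRUCTURE.proto.gz").group(1) == "STRUCTURE"
assert pattern.search("c.POSITION.proto.gz").group(1) == "POSITION"
assert pattern.search("unknown_document_type.proto") is None  # no uppercase type name present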
@pytest.mark.parametrize(
    "document_fixture",
    [
@@ -53,7 +71,7 @@ def proto_data_loader():
        "document_position_document",
    ],
)
def test_proto_data_loader(document_fixture, request, proto_data_loader):
def test_proto_data_loader_end2end(document_fixture, request, proto_data_loader):
    file_path, data = request.getfixturevalue(document_fixture)
    data = gzip.decompress(data)
    loaded_data = proto_data_loader(file_path, data)
@@ -61,3 +79,13 @@ def test_proto_data_loader(document_fixture, request, proto_data_loader):
    # TODO: Right now, we don't have access to proto-json pairs to compare the loaded data with the expected data.
    # If this becomes available, please update this test to compare the loaded data with the expected data.
    assert isinstance(loaded_data, dict) or isinstance(loaded_data, list)


def test_proto_data_loader_unknown_document_type(proto_data_loader):
    with pytest.raises(ValueError):
        proto_data_loader("unknown_document_type.proto", b"")


def test_proto_data_loader_file_name_matching(proto_data_loader, should_match):
    for file_name in should_match:
        assert proto_data_loader._match(file_name) is not None
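
The matching tests only cover positives; a hypothetical companion test, sketched here, would pin down that unrelated names fall through to None:

def test_proto_data_loader_file_name_not_matching(proto_data_loader):
    # hypothetical negative cases: lowercase type names and a missing ".proto" suffix must not match
    for file_name in ["a.structure.proto.gz", "a.STRUCTURE.json", "plain.txt"]:
        assert proto_data_loader._match(file_name) is None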