feat: integrate proto data loader in pipeline
parent 9d55b3be89
commit 0d232226fd
@@ -9,16 +9,38 @@ from pyinfra.proto import DocumentStructure_pb2, DocumentTextData_pb2, DocumentP
 
 
 class ProtoDataLoader:
+    """Loads proto data from a file and returns it as a dictionary or list.
+
+    The loader is a singleton and should be used as a callable. The file name and byte data are passed as arguments.
+
+    The document type is determined based on the file name and the data is returned as a dictionary or list, depending
+    on the document type.
+    The DocumentType enum contains all supported document types and their corresponding proto schema.
+    KEYS_TO_UNPACK contains the keys that should be unpacked from the message dictionary. Keys are unpacked if the
+    message dictionary contains only one key. This behaviour is necessary since lists are wrapped in a dictionary.
+    """
+
+    _instance = None
+    _pattern = None
+
     class DocumentType(Enum):
-        STRUCTURE = "STRUCTURE"
-        TEXT = "TEXT"
-        PAGES = "PAGES"
-        POSITION = "POSITION"
+        STRUCTURE = (DocumentStructure_pb2.DocumentStructure, "DocumentStructure")
+        TEXT = (DocumentTextData_pb2.AllDocumentTextData, "AllDocumentTextData")
+        PAGES = (DocumentPage_pb2.AllDocumentPages, "AllDocumentPages")
+        POSITION = (DocumentPositionData_pb2.AllDocumentPositionData, "AllDocumentPositionData")
 
     KEYS_TO_UNPACK = ["documentTextData", "documentPages", "documentPositionData"]
 
-    def __init__(self):
-        self.pattern = self._build_pattern()
+    @classmethod
+    def _build_pattern(cls) -> re.Pattern:
+        types = "|".join([dt.name for dt in cls.DocumentType])
+        return re.compile(rf"\..*({types}).*\.proto.*")
+
+    def __new__(cls, *args, **kwargs):
+        if cls._instance is None:
+            cls._instance = super().__new__(cls)
+            cls._pattern = cls._build_pattern()
+        return cls._instance
 
     def __call__(self, file_name: str | Path, data: bytes) -> dict:
        return self._load(file_name, data)
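Note: the `__new__` override added above makes ProtoDataLoader a process-wide singleton, so the file-name pattern is compiled exactly once. A minimal sketch of the resulting behaviour (hypothetical usage, not part of this diff):

    loader_a = ProtoDataLoader()
    loader_b = ProtoDataLoader()
    assert loader_a is loader_b  # every construction returns the same instance
    assert loader_a._pattern is loader_b._pattern  # pattern compiled only on first construction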
@@ -27,44 +49,28 @@ class ProtoDataLoader:
         file_name = str(file_name)
         document_type = self._match(file_name)
 
-        logger.info(f"Loading document type: {document_type}")
-        match document_type:
-            case self.DocumentType.STRUCTURE.value:
-                schema = DocumentStructure_pb2
-                message = schema.DocumentStructure()
-            case self.DocumentType.TEXT.value:
-                schema = DocumentTextData_pb2
-                message = schema.AllDocumentTextData()
-            case self.DocumentType.PAGES.value:
-                schema = DocumentPage_pb2
-                message = schema.AllDocumentPages()
-            case self.DocumentType.POSITION.value:
-                schema = DocumentPositionData_pb2
-                message = schema.AllDocumentPositionData()
-            case _:
-                raise ValueError(f"Document type {document_type} not supported")
+        if not document_type:
+            raise ValueError(f"Unknown document type: {file_name}, supported types: {self.DocumentType}")
+
+        logger.debug(f"Loading document type: {document_type}")
+        schema, _ = self.DocumentType[document_type].value
+        message = schema()
         message.ParseFromString(data)
         message_dict = MessageToDict(message)
-        unpacked_message = self._unpack(message_dict)
-
-        return unpacked_message
+        return self._unpack(message_dict)
 
-    def _build_pattern(self) -> str:
-        types = "|".join([doc_type.value for doc_type in self.DocumentType])
-        pattern = r"\..*(" + types + r").*\.proto\..*"
-        return pattern
-
-    def _match(self, file_name: str) -> str:
-        return re.search(self.pattern, file_name).group(1)
+    def _match(self, file_name: str) -> str | None:
+        match = self._pattern.search(file_name)
+        return match.group(1) if match else None
 
     def _unpack(self, message_dict: dict) -> list | dict:
         if len(message_dict) > 1:
             return message_dict
 
-        for key in message_dict:
-            if key in self.KEYS_TO_UNPACK:
-                logger.info(f"Unpacking key: {key}")
+        for key in self.KEYS_TO_UNPACK:
+            if key in message_dict:
+                logger.debug(f"Unpacking key: {key}")
                 return message_dict[key]
 
         return message_dict
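For illustration, the intended call path after this change, assuming `raw` holds already-decompressed proto bytes and the file name carries one of the four type tokens (a hedged sketch mirroring the end2end test further down):

    from pyinfra.storage.proto_data_loader import ProtoDataLoader

    loader = ProtoDataLoader()
    # "TEXT" is extracted from the name, AllDocumentTextData is parsed,
    # and the single "documentTextData" key is unpacked into a list
    records = loader("a.DOCUMENT_TEXT.proto", raw)

The next two hunks wire the loader into the storage download helper.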
@@ -6,6 +6,7 @@ from typing import Union
 from kn_utils.logging import logger
 from pydantic import BaseModel, ValidationError
 
+from pyinfra.storage.proto_data_loader import ProtoDataLoader
 from pyinfra.storage.storages.storage import Storage
 
 
@@ -80,7 +81,14 @@ def _download_single_file(file_path: str, storage: Storage) -> bytes:
     data = storage.get_object(file_path)
 
     data = gzip.decompress(data) if ".gz" in file_path else data
-    data = json.loads(data.decode("utf-8")) if ".json" in file_path else data
+
+    if ".json" in file_path:
+        data = json.loads(data.decode("utf-8"))
+    elif ".proto" in file_path:
+        data = ProtoDataLoader()(file_path, data)
+    else:
+        pass  # identity for other file types
 
     logger.info(f"Downloaded {file_path} from storage.")
 
     return data
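The download helper now dispatches on the file suffix instead of a single ternary; sketched behaviour for the three branches (file names illustrative):

    _download_single_file("doc.json.gz", storage)                 # gunzip, then json.loads -> dict
    _download_single_file("doc.DOCUMENT_TEXT.proto.gz", storage)  # gunzip, then ProtoDataLoader -> dict or list
    _download_single_file("doc.bin", storage)                     # returned as raw bytes

Worth noting: the helper's `-> bytes` annotation predates this change and now under-describes the json and proto branches, which return dicts or lists.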
@@ -44,6 +44,24 @@ def proto_data_loader():
     return ProtoDataLoader()
 
 
+@pytest.fixture
+def should_match():
+    return [
+        "a.DOCUMENT_STRUCTURE.proto.gz",
+        "a.DOCUMENT_TEXT.proto.gz",
+        "a.DOCUMENT_PAGES.proto.gz",
+        "a.DOCUMENT_POSITION.proto.gz",
+        "b.DOCUMENT_STRUCTURE.proto",
+        "b.DOCUMENT_TEXT.proto",
+        "b.DOCUMENT_PAGES.proto",
+        "b.DOCUMENT_POSITION.proto",
+        "c.STRUCTURE.proto.gz",
+        "c.TEXT.proto.gz",
+        "c.PAGES.proto.gz",
+        "c.POSITION.proto.gz",
+    ]
+
+
 @pytest.mark.parametrize(
     "document_fixture",
     [
@@ -53,7 +71,7 @@ def proto_data_loader():
         "document_position_document",
     ],
 )
-def test_proto_data_loader(document_fixture, request, proto_data_loader):
+def test_proto_data_loader_end2end(document_fixture, request, proto_data_loader):
     file_path, data = request.getfixturevalue(document_fixture)
     data = gzip.decompress(data)
     loaded_data = proto_data_loader(file_path, data)
@@ -61,3 +79,13 @@ def test_proto_data_loader(document_fixture, request, proto_data_loader):
     # TODO: Right now, we don't have access to proto-json pairs to compare the loaded data with the expected data.
     # If this becomes available, please update this test to compare the loaded data with the expected data.
     assert isinstance(loaded_data, dict) or isinstance(loaded_data, list)
+
+
+def test_proto_data_loader_unknown_document_type(proto_data_loader):
+    with pytest.raises(ValueError):
+        proto_data_loader("unknown_document_type.proto", b"")
+
+
+def test_proto_data_loader_file_name_matching(proto_data_loader, should_match):
+    for file_name in should_match:
+        assert proto_data_loader._match(file_name) is not None
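As a complement to the should_match fixture, names without one of the four type tokens fall through `_match` and return None, which `_load` converts into the ValueError exercised above; e.g. (illustrative):

    assert proto_data_loader._match("a.UNKNOWN.proto.gz") is None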