From 0d232226fd2d296d21ef6066e0fb45c49814ee3e Mon Sep 17 00:00:00 2001 From: Julius Unverfehrt Date: Tue, 16 Jul 2024 17:34:26 +0200 Subject: [PATCH] feat: integrate proto data loader in pipeline --- pyinfra/storage/proto_data_loader.py | 74 ++++++++++--------- pyinfra/storage/utils.py | 10 ++- ...ader_test.py => proto_data_loader_test.py} | 30 +++++++- 3 files changed, 78 insertions(+), 36 deletions(-) rename tests/unit_test/{data_loader_test.py => proto_data_loader_test.py} (67%) diff --git a/pyinfra/storage/proto_data_loader.py b/pyinfra/storage/proto_data_loader.py index e6ca643..1663a35 100644 --- a/pyinfra/storage/proto_data_loader.py +++ b/pyinfra/storage/proto_data_loader.py @@ -9,16 +9,38 @@ from pyinfra.proto import DocumentStructure_pb2, DocumentTextData_pb2, DocumentP class ProtoDataLoader: + """Loads proto data from a file and returns it as a dictionary or list. + + The loader is a singleton and should be used as a callable. The file name and byte data are passed as arguments. + + The document type is determined based on the file name and the data is returned as a dictionary or list, depending + on the document type. + The DocumentType enum contains all supported document types and their corresponding proto schema. + KEYS_TO_UNPACK contains the keys that should be unpacked from the message dictionary. Keys are unpacked if the + message dictionary contains only one key. This behaviour is necessary since lists are wrapped in a dictionary. + """ + + _instance = None + _pattern = None + class DocumentType(Enum): - STRUCTURE = "STRUCTURE" - TEXT = "TEXT" - PAGES = "PAGES" - POSITION = "POSITION" + STRUCTURE = (DocumentStructure_pb2.DocumentStructure, "DocumentStructure") + TEXT = (DocumentTextData_pb2.AllDocumentTextData, "AllDocumentTextData") + PAGES = (DocumentPage_pb2.AllDocumentPages, "AllDocumentPages") + POSITION = (DocumentPositionData_pb2.AllDocumentPositionData, "AllDocumentPositionData") KEYS_TO_UNPACK = ["documentTextData", "documentPages", "documentPositionData"] - def __init__(self): - self.pattern = self._build_pattern() + @classmethod + def _build_pattern(cls) -> re.Pattern: + types = "|".join([dt.name for dt in cls.DocumentType]) + return re.compile(rf"\..*({types}).*\.proto.*") + + def __new__(cls, *args, **kwargs): + if cls._instance is None: + cls._instance = super().__new__(cls) + cls._pattern = cls._build_pattern() + return cls._instance def __call__(self, file_name: str | Path, data: bytes) -> dict: return self._load(file_name, data) @@ -27,44 +49,28 @@ class ProtoDataLoader: file_name = str(file_name) document_type = self._match(file_name) - logger.info(f"Loading document type: {document_type}") - match document_type: - case self.DocumentType.STRUCTURE.value: - schema = DocumentStructure_pb2 - message = schema.DocumentStructure() - case self.DocumentType.TEXT.value: - schema = DocumentTextData_pb2 - message = schema.AllDocumentTextData() - case self.DocumentType.PAGES.value: - schema = DocumentPage_pb2 - message = schema.AllDocumentPages() - case self.DocumentType.POSITION.value: - schema = DocumentPositionData_pb2 - message = schema.AllDocumentPositionData() - case _: - raise ValueError(f"Document type {document_type} not supported") + if not document_type: + raise ValueError(f"Unknown document type: {file_name}, supported types: {self.DocumentType}") + logger.debug(f"Loading document type: {document_type}") + schema, _ = self.DocumentType[document_type].value + message = schema() message.ParseFromString(data) message_dict = MessageToDict(message) - unpacked_message = self._unpack(message_dict) - return unpacked_message + return self._unpack(message_dict) - def _build_pattern(self) -> str: - types = "|".join([doc_type.value for doc_type in self.DocumentType]) - pattern = r"\..*(" + types + r").*\.proto\..*" - return pattern - - def _match(self, file_name: str) -> str: - return re.search(self.pattern, file_name).group(1) + def _match(self, file_name: str) -> str | None: + match = self._pattern.search(file_name) + return match.group(1) if match else None def _unpack(self, message_dict: dict) -> list | dict: if len(message_dict) > 1: return message_dict - for key in message_dict: - if key in self.KEYS_TO_UNPACK: - logger.info(f"Unpacking key: {key}") + for key in self.KEYS_TO_UNPACK: + if key in message_dict: + logger.debug(f"Unpacking key: {key}") return message_dict[key] return message_dict diff --git a/pyinfra/storage/utils.py b/pyinfra/storage/utils.py index 36b5b81..f2fc5e4 100644 --- a/pyinfra/storage/utils.py +++ b/pyinfra/storage/utils.py @@ -6,6 +6,7 @@ from typing import Union from kn_utils.logging import logger from pydantic import BaseModel, ValidationError +from pyinfra.storage.proto_data_loader import ProtoDataLoader from pyinfra.storage.storages.storage import Storage @@ -80,7 +81,14 @@ def _download_single_file(file_path: str, storage: Storage) -> bytes: data = storage.get_object(file_path) data = gzip.decompress(data) if ".gz" in file_path else data - data = json.loads(data.decode("utf-8")) if ".json" in file_path else data + + if ".json" in file_path: + data = json.loads(data.decode("utf-8")) + elif ".proto" in file_path: + data = ProtoDataLoader()(file_path, data) + else: + pass # identity for other file types + logger.info(f"Downloaded {file_path} from storage.") return data diff --git a/tests/unit_test/data_loader_test.py b/tests/unit_test/proto_data_loader_test.py similarity index 67% rename from tests/unit_test/data_loader_test.py rename to tests/unit_test/proto_data_loader_test.py index 74248e4..3df9187 100644 --- a/tests/unit_test/data_loader_test.py +++ b/tests/unit_test/proto_data_loader_test.py @@ -44,6 +44,24 @@ def proto_data_loader(): return ProtoDataLoader() +@pytest.fixture +def should_match(): + return [ + "a.DOCUMENT_STRUCTURE.proto.gz", + "a.DOCUMENT_TEXT.proto.gz", + "a.DOCUMENT_PAGES.proto.gz", + "a.DOCUMENT_POSITION.proto.gz", + "b.DOCUMENT_STRUCTURE.proto", + "b.DOCUMENT_TEXT.proto", + "b.DOCUMENT_PAGES.proto", + "b.DOCUMENT_POSITION.proto", + "c.STRUCTURE.proto.gz", + "c.TEXT.proto.gz", + "c.PAGES.proto.gz", + "c.POSITION.proto.gz", + ] + + @pytest.mark.parametrize( "document_fixture", [ @@ -53,7 +71,7 @@ def proto_data_loader(): "document_position_document", ], ) -def test_proto_data_loader(document_fixture, request, proto_data_loader): +def test_proto_data_loader_end2end(document_fixture, request, proto_data_loader): file_path, data = request.getfixturevalue(document_fixture) data = gzip.decompress(data) loaded_data = proto_data_loader(file_path, data) @@ -61,3 +79,13 @@ def test_proto_data_loader(document_fixture, request, proto_data_loader): # TODO: Right now, we don't have access to proto-json pairs to compare the loaded data with the expected data. # If this becomes available, please update this test to compare the loaded data with the expected data. assert isinstance(loaded_data, dict) or isinstance(loaded_data, list) + + +def test_proto_data_loader_unknown_document_type(proto_data_loader): + with pytest.raises(ValueError): + proto_data_loader("unknown_document_type.proto", b"") + + +def test_proto_data_loader_file_name_matching(proto_data_loader, should_match): + for file_name in should_match: + assert proto_data_loader._match(file_name) is not None