pyinfra/pyinfra/storage/proto_data_loader.py
2024-09-25 11:07:20 +02:00

128 lines
4.7 KiB
Python

import re
from enum import Enum
from pathlib import Path
from google.protobuf.json_format import MessageToDict
from kn_utils.logging import logger
from pyinfra.proto import (
DocumentPage_pb2,
DocumentPositionData_pb2,
DocumentStructure_pb2,
DocumentTextData_pb2,
)
class ProtoDataLoader:
    """Load serialized proto data and return it as a dictionary or list.

    The loader is a singleton and is meant to be used as a callable: pass the
    file name and the raw bytes. The document type is derived from the file
    name with a regex built from the ``DocumentType`` member names, the bytes
    are parsed with the matching proto schema and the message is converted
    with ``MessageToDict``.

    ``DocumentType`` maps every supported document type to a tuple of its
    proto schema class and the schema name.

    ``KEYS_TO_UNPACK`` lists wrapper keys. If the converted message dict has
    at most one key and that key is listed here, the wrapped value is returned
    instead of the dict. This is necessary since lists are wrapped in a
    dictionary.
    """

    # Shared singleton instance and the compiled file-name pattern; both are
    # populated on first instantiation in __new__.
    _instance = None
    _pattern = None

    class DocumentType(Enum):
        STRUCTURE = (DocumentStructure_pb2.DocumentStructure, "DocumentStructure")
        TEXT = (DocumentTextData_pb2.AllDocumentTextData, "AllDocumentTextData")
        PAGES = (DocumentPage_pb2.AllDocumentPages, "AllDocumentPages")
        POSITION = (DocumentPositionData_pb2.AllDocumentPositionData, "AllDocumentPositionData")

    KEYS_TO_UNPACK = ["documentTextData", "documentPages", "documentPositionData"]

    @classmethod
    def _build_pattern(cls) -> re.Pattern:
        """Compile the regex that extracts the document type from a file name.

        The pattern captures one of the ``DocumentType`` member names that
        appears after a dot and before a later ``.proto`` in the name.
        """
        types = "|".join(dt.name for dt in cls.DocumentType)
        return re.compile(rf"\..*({types}).*\.proto.*")

    def __new__(cls, *args, **kwargs):
        """Return the shared instance, creating it (and the pattern) once."""
        if cls._instance is None:
            cls._instance = super().__new__(cls)
            cls._pattern = cls._build_pattern()
        return cls._instance

    def __call__(self, file_name: str | Path, data: bytes) -> dict | list:
        """Load ``data`` according to the document type found in ``file_name``.

        Args:
            file_name: Name or path of the proto file; only used for type detection.
            data: Serialized proto message.

        Returns:
            The converted message as dict or list; an empty dict if the
            document type cannot be determined from the file name.
        """
        return self._load(file_name, data)

    def _load(self, file_name: str | Path, data: bytes) -> dict | list:
        """Parse ``data`` with the schema matching ``file_name`` and post-process it."""
        file_name = str(file_name)
        document_type = self._match(file_name)

        if not document_type:
            # List the member names; interpolating the enum class itself only
            # renders its repr, which tells the reader nothing.
            supported = ", ".join(dt.name for dt in self.DocumentType)
            logger.error(f"Unknown document type: {file_name}, supported types: {supported}")
            return {}

        logger.debug(f"Loading document type: {document_type}")
        schema, _ = self.DocumentType[document_type].value
        message = schema()
        message.ParseFromString(data)

        # NOTE(review): newer protobuf releases appear to have renamed
        # `including_default_value_fields` - confirm against the pinned version.
        message_dict = MessageToDict(message, including_default_value_fields=True)
        message_dict = convert_int64_fields(message_dict)
        if document_type == "POSITION":
            message_dict = transform_positions_to_list(message_dict)

        return self._unpack(message_dict)

    def _match(self, file_name: str) -> str | None:
        """Return the document type name found in ``file_name``, or ``None``."""
        match = self._pattern.search(file_name)
        return match.group(1) if match else None

    def _unpack(self, message_dict: dict) -> list | dict:
        """Return the wrapped value if the dict is a single-key wrapper.

        Dicts with more than one key are returned unchanged. Otherwise the
        value of the first ``KEYS_TO_UNPACK`` key present is returned; if none
        is present the dict itself is returned.
        """
        if len(message_dict) > 1:
            return message_dict

        for key in self.KEYS_TO_UNPACK:
            if key in message_dict:
                logger.debug(f"Unpacking key: {key}")
                return message_dict[key]

        return message_dict
def convert_int64_fields(obj):
    """Recursively convert decimal-digit strings in a message dict to ``int``.

    Dicts are updated in place and returned; lists are rebuilt as new lists;
    strings consisting only of decimal digits are converted with ``int``;
    every other value is returned unchanged. Values stored under the keys
    ``col``, ``row``, ``numberOfCols`` and ``numberOfRows`` are left untouched.

    Args:
        obj: A dict, list or scalar taken from a converted proto message.

    Returns:
        The converted structure (the same object for dict input).
    """
    # FIXME: find a more sophisticated way to convert int64 fields (defaults to str in python)
    # we skip the following keys because the values are expected to be of type str
    skip_keys = ("col", "row", "numberOfCols", "numberOfRows")

    if isinstance(obj, dict):
        # Only existing keys are reassigned, so the dict size never changes
        # while it is being iterated.
        for key, value in obj.items():
            if key not in skip_keys:
                obj[key] = convert_int64_fields(value)
        return obj

    if isinstance(obj, list):
        return [convert_int64_fields(item) for item in obj]

    # isdecimal() instead of isdigit(): isdigit() is also True for characters
    # such as superscripts or circled digits, which int() rejects with a
    # ValueError. isdecimal() only matches strings int() can parse.
    # NOTE(review): signed values such as "-5" do not match and stay strings -
    # confirm whether any int64 field can be negative.
    if isinstance(obj, str) and obj.isdecimal():
        return int(obj)

    return obj
def transform_positions_to_list(obj: dict | list) -> dict:
"""Transforms the repeated fields 'positions' to a lists of lists of floats
as expected by DocumentReader.
Args:
obj (dict | list): Proto message dict
Returns:
dict: Proto message dict
"""
if isinstance(obj, dict):
# Check if 'positions' is in the dictionary and reshape it as list of lists of floats
if "positions" in obj and isinstance(obj["positions"], list):
obj["positions"] = [pos["value"] for pos in obj["positions"] if isinstance(pos, dict) and "value" in pos]
# Recursively apply to all nested dictionaries
for key, value in obj.items():
obj[key] = transform_positions_to_list(value)
elif isinstance(obj, list):
# Recursively apply to all items in the list
obj = [transform_positions_to_list(item) for item in obj]
return obj