128 lines
4.7 KiB
Python
128 lines
4.7 KiB
Python
import re
|
|
from enum import Enum
|
|
from pathlib import Path
|
|
|
|
from google.protobuf.json_format import MessageToDict
|
|
from kn_utils.logging import logger
|
|
|
|
from pyinfra.proto import (
|
|
DocumentPage_pb2,
|
|
DocumentPositionData_pb2,
|
|
DocumentStructure_pb2,
|
|
DocumentTextData_pb2,
|
|
)
|
|
|
|
|
|
class ProtoDataLoader:
    """Loads proto data from a file and returns it as a dictionary or list.

    The loader is a singleton and should be used as a callable. The file name and byte data are passed as arguments.

    The document type is determined based on the file name and the data is returned as a dictionary or list, depending
    on the document type.

    The DocumentType enum contains all supported document types and their corresponding proto schema.

    KEYS_TO_UNPACK contains the keys that should be unpacked from the message dictionary. Keys are unpacked if the
    message dictionary contains only one key. This behaviour is necessary since lists are wrapped in a dictionary.
    """

    # Shared singleton instance; created on the first instantiation in __new__.
    _instance = None
    # Compiled file-name regex; built once, together with the singleton instance.
    _pattern = None

    class DocumentType(Enum):
        # Each member's value is (proto message class, proto message name).
        # Member NAMES are what the file-name regex looks for.
        STRUCTURE = (DocumentStructure_pb2.DocumentStructure, "DocumentStructure")
        TEXT = (DocumentTextData_pb2.AllDocumentTextData, "AllDocumentTextData")
        PAGES = (DocumentPage_pb2.AllDocumentPages, "AllDocumentPages")
        POSITION = (DocumentPositionData_pb2.AllDocumentPositionData, "AllDocumentPositionData")

    KEYS_TO_UNPACK = ["documentTextData", "documentPages", "documentPositionData"]

    @classmethod
    def _build_pattern(cls) -> re.Pattern:
        """Compile the regex used to detect the document type from a file name.

        Group 1 captures a ``DocumentType`` member name that appears after a "."
        and before ".proto", e.g. ``"doc.TEXT.proto.gz"`` -> ``"TEXT"``.
        """
        types = "|".join(dt.name for dt in cls.DocumentType)
        return re.compile(rf"\..*({types}).*\.proto.*")

    def __new__(cls, *args, **kwargs):
        """Return the shared instance, creating it (and the regex) on first use."""
        if cls._instance is None:
            cls._instance = super().__new__(cls)
            cls._pattern = cls._build_pattern()
        return cls._instance

    def __call__(self, file_name: str | Path, data: bytes) -> dict | list:
        """Parse ``data`` according to the document type encoded in ``file_name``.

        Args:
            file_name: Name or path of the file; used only to determine the document type.
            data: Serialized proto message bytes.

        Returns:
            The message as a dict, or as a list when a single wrapper key was unpacked.
            An empty dict if the document type cannot be determined.
        """
        return self._load(file_name, data)

    def _load(self, file_name: str | Path, data: bytes) -> dict | list:
        """Detect the document type, parse ``data`` and post-process the resulting dict."""
        file_name = str(file_name)
        document_type = self._match(file_name)

        if not document_type:
            # Log the member names; formatting the enum class itself would only print "<enum 'DocumentType'>".
            supported_types = [dt.name for dt in self.DocumentType]
            logger.error(f"Unknown document type: {file_name}, supported types: {supported_types}")
            return {}

        logger.debug(f"Loading document type: {document_type}")
        schema, _ = self.DocumentType[document_type].value
        message = schema()
        message.ParseFromString(data)
        message_dict = MessageToDict(message, including_default_value_fields=True)
        # MessageToDict renders int64 values as strings; convert them back to ints.
        message_dict = convert_int64_fields(message_dict)
        if document_type == self.DocumentType.POSITION.name:
            message_dict = transform_positions_to_list(message_dict)

        return self._unpack(message_dict)

    def _match(self, file_name: str) -> str | None:
        """Return the DocumentType member name found in ``file_name``, or None."""
        match = self._pattern.search(file_name)
        return match.group(1) if match else None

    def _unpack(self, message_dict: dict) -> list | dict:
        """Return the value of the single wrapper key in KEYS_TO_UNPACK, if present.

        Dicts with more than one key, or whose only key is not in
        KEYS_TO_UNPACK, are returned unchanged.
        """
        if len(message_dict) > 1:
            return message_dict

        for key in self.KEYS_TO_UNPACK:
            if key in message_dict:
                logger.debug(f"Unpacking key: {key}")
                return message_dict[key]

        return message_dict
|
|
|
|
|
|
# Keys whose values are left untouched (neither converted nor recursed into)
# because the values are expected to be of type str.
_INT64_SKIP_KEYS = frozenset({"col", "row", "numberOfCols", "numberOfRows"})


def convert_int64_fields(obj):
    """Recursively convert decimal-digit strings to ``int``.

    Dicts are updated in place and returned. Lists are rebuilt as new lists.
    Any other value is returned as-is, except strings made only of decimal
    digits, which become ``int``. Values under the keys in
    ``_INT64_SKIP_KEYS`` are left unchanged.

    Args:
        obj: A message dict (or nested value) as produced by MessageToDict.

    Returns:
        The converted value.
    """
    # FIXME: find a more sophisticated way to convert int64 fields (defaults to str in python)
    # NOTE(review): negative int64 values ("-5") are not converted, because the sign
    # is not a decimal digit. Confirm whether that is intended.
    if isinstance(obj, dict):
        for key, value in obj.items():
            if key in _INT64_SKIP_KEYS:
                continue
            # Reassigning an existing key while iterating is safe: the dict size does not change.
            obj[key] = convert_int64_fields(value)
    elif isinstance(obj, list):
        return [convert_int64_fields(item) for item in obj]
    elif isinstance(obj, str) and obj.isdecimal():
        # isdecimal() (not isdigit()) accepts exactly what int() can parse.
        # isdigit() is also true for characters like "²", and int("²") raises ValueError.
        return int(obj)
    return obj
|
|
|
|
|
|
def transform_positions_to_list(obj: dict | list) -> dict:
|
|
"""Transforms the repeated fields 'positions' to a lists of lists of floats
|
|
as expected by DocumentReader.
|
|
|
|
Args:
|
|
obj (dict | list): Proto message dict
|
|
|
|
Returns:
|
|
dict: Proto message dict
|
|
"""
|
|
if isinstance(obj, dict):
|
|
# Check if 'positions' is in the dictionary and reshape it as list of lists of floats
|
|
if "positions" in obj and isinstance(obj["positions"], list):
|
|
obj["positions"] = [pos["value"] for pos in obj["positions"] if isinstance(pos, dict) and "value" in pos]
|
|
|
|
# Recursively apply to all nested dictionaries
|
|
for key, value in obj.items():
|
|
obj[key] = transform_positions_to_list(value)
|
|
elif isinstance(obj, list):
|
|
# Recursively apply to all items in the list
|
|
obj = [transform_positions_to_list(item) for item in obj]
|
|
|
|
return obj
|