142 lines
4.1 KiB
Python
142 lines
4.1 KiB
Python
import argparse
|
|
from dataclasses import asdict, dataclass, field
|
|
from operator import truth
|
|
from typing import List
|
|
|
|
import cv2
|
|
import numpy as np
|
|
from funcy import lfilter, lflatten, lmap
|
|
from pdf2img.conversion import convert_pages_to_images
|
|
|
|
from cv_analysis.table_parsing import parse_tables
|
|
from cv_analysis.utils.display import show_image_mpl
|
|
|
|
|
|
def parse_args(argv=None):
    """Parse the command-line arguments of this script.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) lets
            argparse read ``sys.argv[1:]``, exactly as before.

    Returns:
        argparse.Namespace with ``pdf_path`` (str) and ``index`` (int,
        default 0).
    """
    parser = argparse.ArgumentParser(description="Parse the layout of one PDF page into a node tree.")
    parser.add_argument("pdf_path", help="path to the input PDF file")
    parser.add_argument(
        "--index",
        "-i",
        type=int,
        default=0,
        help="page index passed to the PDF converter (default: 0)",
    )
    return parser.parse_args(argv)
|
|
|
|
|
|
def load_image(pdf_path, index):
    """Render one page of a PDF file and return it as an array.

    Args:
        pdf_path: Path of the PDF file to read.
        index: Page index forwarded to ``convert_pages_to_images``.

    Returns:
        The rendered page as returned by ``page.asarray()``. Callers treat it
        as a (height, width, 3) image -- presumably a numpy array; TODO confirm.

    Raises:
        IndexError: If the converter yields no page for ``index``.
    """
    with open(pdf_path, "rb") as f:
        pdf = f.read()
    # A bare next() would leak StopIteration for a missing page, which is easy
    # to confuse with ordinary iterator exhaustion; raise a clear error instead.
    page = next(iter(convert_pages_to_images(pdf, index=[index])), None)
    if page is None:
        raise IndexError(f"no page with index {index} in {pdf_path!r}")
    return page.asarray()
|
|
|
|
|
|
@dataclass
class Node:
    """Axis-aligned box in absolute image coordinates, with labels and sub-boxes.

    ``(x0, y0)`` and ``(x1, y1)`` are opposite corners of the box
    (``x1 = x0 + width``, ``y1 = y0 + height`` when built by ``make_child``).
    """

    x0: int
    y0: int
    x1: int
    y1: int
    # Labels attached to this box, e.g. "root", "cell", "table", "text".
    types: List = field(default_factory=list)
    # Nested Node instances lying inside this box.
    children: List = field(default_factory=list)

    def __post_init__(self):
        # make_child(..., types=None) forwards None explicitly; normalise so that
        # membership tests and append() on these lists never hit None.
        if self.types is None:
            self.types = []
        if self.children is None:
            self.children = []

    def has_children(self):
        """Return True if this node has at least one child."""
        return truth(self.children)

    def asdict(self):
        """Return this node and, recursively, its children as plain dicts/lists."""
        # Inside the method body this name resolves to the module-level
        # dataclasses.asdict, not to this method.
        return asdict(self)
|
|
|
|
|
|
def make_child(xywh, parent, types=None):
    """Create a Node from a box expressed relative to ``parent``.

    Args:
        xywh: ``(x, y, width, height)`` of the box, relative to
            ``(parent.x0, parent.y0)``.
        parent: Node whose ``x0``/``y0`` offset the new box.
        types: Optional iterable of labels for the new node; defaults to none.

    Returns:
        A new childless Node in absolute coordinates.
    """
    x, y, width, height = xywh
    left, top = parent.x0 + x, parent.y0 + y
    # Always give the Node its own list: forwarding None left node.types as
    # None, which breaks ``"cell" in node.types`` in classify_node_recursively.
    labels = [] if types is None else list(types)
    return Node(left, top, left + width, top + height, labels)
|
|
|
|
|
|
def parse_basic_layout(image, *, threshold=50, kernel_size=(30, 20)):
    """Find coarse layout regions on a page image.

    The image is converted to grayscale and inverted, binarised, and then
    morphologically closed with an elliptical kernel so that nearby marks merge
    into blobs. The bounding box of every outer contour is returned.

    Args:
        image: 3-channel image accepted by ``cv2.COLOR_BGR2GRAY``.
        threshold: Binarisation threshold on the inverted grayscale image;
            pixels above it become 255, all others 0.
        kernel_size: ``(width, height)`` of the elliptical closing kernel.

    Returns:
        List of ``(x, y, w, h)`` tuples from ``cv2.boundingRect``.
    """
    inverted = ~cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
    _, binary = cv2.threshold(inverted, threshold, 255, cv2.THRESH_BINARY)

    kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, kernel_size)
    closed = cv2.morphologyEx(binary, cv2.MORPH_CLOSE, kernel)

    contours, _ = cv2.findContours(image=closed, mode=cv2.RETR_EXTERNAL, method=cv2.CHAIN_APPROX_SIMPLE)
    return lmap(cv2.boundingRect, contours)
|
|
|
|
|
|
def parse_to_tree(image):
    """Build the layout tree for a whole page image.

    A root Node labelled "root" spans the full image. It receives one child per
    region found by ``parse_basic_layout``, and each child is classified by
    ``classify_node_recursively`` before being attached.

    Returns:
        The root Node.
    """
    height, width = image.shape[:2]
    root = Node(0, 0, width, height, ["root"])
    for region in parse_basic_layout(image):
        child = make_child(region, root)
        root.children.append(classify_node_recursively(image, child))
    return root
|
|
|
|
|
|
def classify_node_recursively(image, node, *, debug=False):
    """Label ``node`` and attach its sub-regions as children.

    The clip of ``image`` covered by ``node`` is first passed to
    ``parse_tables``. If any regions are found, the node is tagged "table", and
    each region becomes a "cell" child that is classified recursively.
    Otherwise, the boxes from ``detect_text`` become "text" children. A node
    with neither is returned unchanged.

    Args:
        image: Full page image; node coordinates index into it directly.
        node: Node to classify; mutated in place.
        debug: If True, print the shape of each preprocessed "cell" clip and
            display it. Off by default so the recursion does not open a view
            for every cell.

    Returns:
        ``node`` itself.
    """
    clip = image[node.y0 : node.y1, node.x0 : node.x1]

    if "cell" in node.types:
        # Cells are re-binarised (inverted grayscale, threshold 50) before
        # analysis. NOTE: removing a surviving cell frame, e.g. via
        # cv2.floodFill(clip, None, (0, 0), 255), is currently disabled.
        clip = ~cv2.cvtColor(clip, cv2.COLOR_BGR2GRAY)
        _, clip = cv2.threshold(clip, 50, 255, cv2.THRESH_BINARY)
        if debug:
            print(clip.shape)
            show_image_mpl(clip)

    maybe_tables = parse_tables(clip)
    if maybe_tables:
        # Region boxes are relative to the clip; make_child offsets them by node.
        bboxes = lmap(lambda r: r.xywh(), maybe_tables)
        children = lmap(lambda b: make_child(b, node, ["cell"]), bboxes)
        children = lmap(lambda c: classify_node_recursively(image, c, debug=debug), children)
        node.types.append("table")
        node.children = children
        return node

    maybe_texts = detect_text(clip)
    if maybe_texts:
        node.children = lmap(lambda b: make_child(b, node, ["text"]), maybe_texts)

    # TODO: figure classification is not implemented; nodes without tables
    # or text are returned as-is.
    return node
|
|
|
|
|
|
def detect_text(image, *, max_width=500, max_height=40):
    """Return bounding boxes of small, text-sized blobs in ``image``.

    An image with more than two dimensions is converted to inverted grayscale
    and binarised at 50. A 2-D image is assumed to be binarised already, as
    the "cell" clips in classify_node_recursively are. The binary image is
    dilated once with a 5x5 elliptical kernel so neighbouring glyphs merge.
    Outer-contour bounding boxes are then kept only if they are smaller than
    the size limits.

    Args:
        image: Colour image, or already-binarised single-channel image.
        max_width: Boxes must be strictly narrower than this (pixels).
        max_height: Boxes must be strictly shorter than this (pixels).

    Returns:
        List of ``(x, y, w, h)`` tuples relative to ``image``.
    """
    if len(image.shape) > 2:
        image = ~cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
        _, image = cv2.threshold(image, 50, 255, cv2.THRESH_BINARY)

    # The shape argument used to be cv2.MORPH_OPEN, a morphology *operation*
    # code. Its value (2) happens to equal cv2.MORPH_ELLIPSE, so naming the
    # intended shape keeps the kernel identical.
    kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5))
    dilated = cv2.dilate(image, kernel, iterations=1)

    contours, _ = cv2.findContours(image=dilated, mode=cv2.RETR_EXTERNAL, method=cv2.CHAIN_APPROX_SIMPLE)

    bboxes = lmap(cv2.boundingRect, contours)
    return lfilter(lambda bbox: bbox[2] < max_width and bbox[3] < max_height, bboxes)
|
|
|
|
|
|
def draw_node(image, node):
    """Draw ``node`` and all of its descendants onto ``image`` in place.

    Each node gets a 1-px (0, 255, 0) rectangle, and its ``types`` list is
    written as text at ``(x0, y0)``. Nodes are drawn parent first, with
    children in list order.
    """
    pending = [node]
    while pending:
        current = pending.pop()
        origin = (current.x0, current.y0)
        cv2.rectangle(image, origin, (current.x1, current.y1), color=(0, 255, 0), thickness=1)
        cv2.putText(
            image,
            str(current.types),
            org=origin,
            fontFace=cv2.FONT_HERSHEY_SIMPLEX,
            fontScale=0.5,
            color=(0, 100, 100),
            thickness=1,
        )
        if current.has_children():
            # Push in reverse so the stack pops children in their original order.
            pending.extend(reversed(list(current.children)))
|
|
|
|
|
|
def main():
    """Parse one PDF page into a layout tree, show it annotated, and print the tree."""
    args = parse_args()
    image = load_image(args.pdf_path, args.index)
    root_node = parse_to_tree(image)
    # draw_node annotates the page image in place before it is shown.
    draw_node(image, root_node)
    show_image_mpl(image)
    print(root_node.asdict())


if __name__ == "__main__":
    main()
|