From 506ed789f71780bb9dda18231dc6203763366b79 Mon Sep 17 00:00:00 2001 From: Julius Unverfehrt Date: Tue, 13 Dec 2022 11:16:15 +0100 Subject: [PATCH] add explorative script for hierarichal layout parsing --- scripts/explore_aio_detection.py | 141 +++++++++++++++++++++++++++++++ 1 file changed, 141 insertions(+) create mode 100644 scripts/explore_aio_detection.py diff --git a/scripts/explore_aio_detection.py b/scripts/explore_aio_detection.py new file mode 100644 index 0000000..6034554 --- /dev/null +++ b/scripts/explore_aio_detection.py @@ -0,0 +1,141 @@ +import argparse +from dataclasses import dataclass, asdict, field +from operator import truth +from typing import List + +import cv2 +import numpy as np +from funcy import lfilter, lmap, lflatten + +from cv_analysis.table_parsing import parse_tables +from cv_analysis.utils.display import show_image_mpl +from pdf2img.conversion import convert_pages_to_images + + +def parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument("pdf_path") + parser.add_argument("--index", "-i", type=int, default=0) + return parser.parse_args() + + +def load_image(pdf_path, index): + with open(pdf_path, "rb") as f: + pdf = f.read() + page = next(convert_pages_to_images(pdf, index=[index])) + return page.asarray() + + +@dataclass +class Node: + x0: int + y0: int + x1: int + y1: int + + types: List = field(default_factory=lambda: []) + children: List = field(default_factory=lambda: []) + + def has_children(self): + return truth(self.children) + + def asdict(self): + return asdict(self) + + +def make_child(xywh, parent, types=None): + x0, y0, w, h = xywh + x1, y1 = x0 + w, y0 + h + return Node(x0 + parent.x0, y0 + parent.y0, x1 + parent.x0, y1 + parent.y0, types) + + +def parse_basic_layout(image): + image = ~cv2.cvtColor(image, cv2.COLOR_BGR2GRAY) + _, image = cv2.threshold(image, 50, 255, cv2.THRESH_BINARY) + + kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (30, 20)) + image = cv2.morphologyEx(image, cv2.MORPH_CLOSE, kernel) + + contours, _ = cv2.findContours(image=image, mode=cv2.RETR_EXTERNAL, method=cv2.CHAIN_APPROX_SIMPLE) + bboxes = lmap(cv2.boundingRect, contours) + return bboxes + + +def parse_to_tree(image): + layout_bboxes = parse_basic_layout(image) + + root_node = Node(0, 0, image.shape[1], image.shape[0], ["root"]) + children = lmap(lambda c: make_child(c, root_node), layout_bboxes) + children = lmap(lambda c: classify_node_recursively(image, c), children) + root_node.children += children + + return root_node + + +def classify_node_recursively(image, node): + clip = image[node.y0 : node.y1, node.x0 : node.x1] + if "cell" in node.types: # remove possible surviving frame + clip = ~cv2.cvtColor(clip, cv2.COLOR_BGR2GRAY) + _, clip = cv2.threshold(clip, 50, 255, cv2.THRESH_BINARY) + # cv2.floodFill(clip, None, (0, 0), 255) + print(clip.shape) + show_image_mpl(clip) + + maybe_tables = parse_tables(clip) + if maybe_tables: + bboxes = lmap(lambda r: r.xywh(), maybe_tables) + children = lmap(lambda b: make_child(b, node, ["cell"]), bboxes) + children = lmap(lambda c: classify_node_recursively(image, c), children) + node.types.append("table") + node.children = children + return node + + maybe_texts = detect_text(clip) + if maybe_texts: + children = lmap(lambda b: make_child(b, node, ["text"]), maybe_texts) + node.children = children + return node #FIGURES + + return node + + +def detect_text(image): + if len(image.shape) > 2: + image = ~cv2.cvtColor(image, cv2.COLOR_BGR2GRAY) + _, image = cv2.threshold(image, 50, 255, cv2.THRESH_BINARY) + + kernel = cv2.getStructuringElement(cv2.MORPH_OPEN, (5, 5)) + clip = cv2.dilate(image, kernel, iterations=1) + + contours, _ = cv2.findContours(image=clip, mode=cv2.RETR_EXTERNAL, method=cv2.CHAIN_APPROX_SIMPLE) + + bboxes = lmap(cv2.boundingRect, contours) + text_max_height = 40 + text_max_width = 500 + bboxes = lfilter(lambda bbox: bbox[2] < text_max_width and bbox[3] < text_max_height, bboxes) + return bboxes + + +def draw_node(image, node): + cv2.rectangle(image, (node.x0, node.y0), (node.x1, node.y1), color=(0, 255, 0), thickness=1) + cv2.putText( + image, + str(node.types), + org=(node.x0, node.y0), + fontFace=cv2.FONT_HERSHEY_SIMPLEX, + fontScale=0.5, + color=(0, 100, 100), + thickness=1, + ) + if node.has_children(): + for child in node.children: + draw_node(image, child) + + +if __name__ == "__main__": + args = parse_args() + image = load_image(args.pdf_path, args.index) + root_node = parse_to_tree(image) + draw_node(image, root_node) + show_image_mpl(image) + print(root_node.asdict())