# cv-analysis-service/scripts/explore_aio_detection.py
# Snapshot: 2024-04-29 12:09:44 +02:00 — 142 lines, 4.1 KiB, Python

import argparse
from dataclasses import asdict, dataclass, field
from operator import truth
from typing import List
import cv2
import numpy as np
from funcy import lfilter, lflatten, lmap
from pdf2img.conversion import convert_pages_to_images
from cv_analysis.table_parsing import parse_tables
from cv_analysis.utils.display import show_image_mpl
def parse_args():
    """Parse command-line arguments.

    Expects a positional PDF path and an optional zero-based page index
    (``--index`` / ``-i``, default 0).
    """
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument("pdf_path")
    arg_parser.add_argument("--index", "-i", type=int, default=0)
    return arg_parser.parse_args()
def load_image(pdf_path, index):
    """Render one page of the PDF at *pdf_path* (zero-based *index*) as an array."""
    with open(pdf_path, "rb") as pdf_file:
        pdf_bytes = pdf_file.read()
    # convert_pages_to_images yields one rendered page per requested index.
    first_page = next(convert_pages_to_images(pdf_bytes, index=[index]))
    return first_page.asarray()
@dataclass
class Node:
    """Axis-aligned page region with layout-type labels and nested children.

    Coordinates are absolute pixel positions within the page image:
    (x0, y0) is the top-left corner, (x1, y1) the bottom-right.
    """

    x0: int
    y0: int
    x1: int
    y1: int
    # Layout labels such as "root", "table", "cell", "text".
    types: List = field(default_factory=list)
    # Child Node regions nested inside this one.
    children: List = field(default_factory=list)

    def __post_init__(self):
        # Callers (e.g. make_child) may pass types=None explicitly, which
        # bypasses the default factory and would later crash membership
        # tests and .append() calls; normalize to an empty list here.
        if self.types is None:
            self.types = []

    def has_children(self):
        """Return True when at least one child region is attached."""
        return bool(self.children)

    def asdict(self):
        """Recursively convert this node (and its children) to plain dicts."""
        return asdict(self)
def make_child(xywh, parent, types=None):
    """Build a child Node from an (x, y, w, h) box relative to *parent*.

    Coordinates are translated into the parent's absolute frame.

    BUG FIX: the original forwarded ``types`` verbatim, so the default
    ``None`` ended up stored on the Node (bypassing its list factory) and
    later crashed ``"cell" in node.types`` / ``node.types.append(...)``.
    A fresh list is substituted when *types* is None.
    """
    x0, y0, w, h = xywh
    x1, y1 = x0 + w, y0 + h
    return Node(
        x0 + parent.x0,
        y0 + parent.y0,
        x1 + parent.x0,
        y1 + parent.y0,
        types if types is not None else [],
    )
def parse_basic_layout(image):
    """Detect coarse content regions on a BGR page image.

    Returns a list of (x, y, w, h) bounding boxes, one per external
    contour found after binarization and morphological closing.
    """
    inverted = ~cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
    _, binary = cv2.threshold(inverted, 50, 255, cv2.THRESH_BINARY)
    # A wide elliptical kernel merges nearby glyphs into region-level blobs.
    kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (30, 20))
    closed = cv2.morphologyEx(binary, cv2.MORPH_CLOSE, kernel)
    contours, _ = cv2.findContours(
        image=closed, mode=cv2.RETR_EXTERNAL, method=cv2.CHAIN_APPROX_SIMPLE
    )
    return [cv2.boundingRect(contour) for contour in contours]
def parse_to_tree(image):
    """Build the layout tree: a full-page root with classified child regions."""
    root = Node(0, 0, image.shape[1], image.shape[0], ["root"])
    for bbox in parse_basic_layout(image):
        child = make_child(bbox, root)
        root.children.append(classify_node_recursively(image, child))
    return root
def classify_node_recursively(image, node):
    """Classify *node* in place as table / cell / text and recurse downward.

    Crops the node's region from the full page *image* (absolute
    coordinates), then tries table detection first and falls back to text
    detection. Returns the same node object, mutated.
    """
    clip = image[node.y0 : node.y1, node.x0 : node.x1]
    if "cell" in node.types:  # remove possible surviving frame
        # Re-binarize the cell clip; presumably this suppresses leftover
        # table border pixels touching the crop — TODO confirm.
        clip = ~cv2.cvtColor(clip, cv2.COLOR_BGR2GRAY)
        _, clip = cv2.threshold(clip, 50, 255, cv2.THRESH_BINARY)
        # cv2.floodFill(clip, None, (0, 0), 255)
        # NOTE(review): debug output left over from exploration.
        print(clip.shape)
        show_image_mpl(clip)
    maybe_tables = parse_tables(clip)
    if maybe_tables:
        # Each detected table cell becomes a child tagged "cell" and is
        # classified recursively (cells may contain text or nested tables).
        bboxes = lmap(lambda r: r.xywh(), maybe_tables)
        children = lmap(lambda b: make_child(b, node, ["cell"]), bboxes)
        children = lmap(lambda c: classify_node_recursively(image, c), children)
        node.types.append("table")
        node.children = children
        return node
    maybe_texts = detect_text(clip)
    if maybe_texts:
        # Text children are leaves — no further recursion.
        children = lmap(lambda b: make_child(b, node, ["text"]), maybe_texts)
        node.children = children
        return node  # FIGURES
    # Neither table nor text detected: leave the node unclassified.
    return node
def detect_text(image):
    """Detect text-like regions and return their (x, y, w, h) boxes.

    Color input is binarized first; already-2D input is assumed binary.
    Boxes wider than 500 px or taller than 40 px are rejected as non-text.
    """
    if len(image.shape) > 2:
        image = ~cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
        _, image = cv2.threshold(image, 50, 255, cv2.THRESH_BINARY)
    # BUG FIX: getStructuringElement expects a *shape* constant
    # (MORPH_RECT/CROSS/ELLIPSE), not a morphology operation. The original
    # passed cv2.MORPH_OPEN, which only worked because its value (2)
    # coincides with cv2.MORPH_ELLIPSE. Behavior is unchanged.
    kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5))
    # Dilation bridges adjacent glyphs so each word/line forms one contour.
    clip = cv2.dilate(image, kernel, iterations=1)
    contours, _ = cv2.findContours(
        image=clip, mode=cv2.RETR_EXTERNAL, method=cv2.CHAIN_APPROX_SIMPLE
    )
    bboxes = lmap(cv2.boundingRect, contours)
    text_max_height = 40
    text_max_width = 500
    return lfilter(
        lambda bbox: bbox[2] < text_max_width and bbox[3] < text_max_height,
        bboxes,
    )
def draw_node(image, node):
    """Draw the node's bounding box and type labels onto *image*, then recurse."""
    top_left = (node.x0, node.y0)
    bottom_right = (node.x1, node.y1)
    cv2.rectangle(image, top_left, bottom_right, color=(0, 255, 0), thickness=1)
    cv2.putText(
        image,
        str(node.types),
        org=top_left,
        fontFace=cv2.FONT_HERSHEY_SIMPLEX,
        fontScale=0.5,
        color=(0, 100, 100),
        thickness=1,
    )
    # Iterating an empty children list is a no-op, so no guard is needed.
    for child in node.children:
        draw_node(image, child)
if __name__ == "__main__":
    # Render the requested PDF page, build the layout tree, and visualize it.
    cli_args = parse_args()
    page_image = load_image(cli_args.pdf_path, cli_args.index)
    tree_root = parse_to_tree(page_image)
    draw_node(page_image, tree_root)
    show_image_mpl(page_image)
    print(tree_root.asdict())