[WIP] Refactoring meta-detection

This commit is contained in:
Matthias Bisping 2023-01-09 14:20:22 +01:00
parent 8327794685
commit 012e705e70
2 changed files with 30 additions and 47 deletions

View File

@ -3,14 +3,11 @@ from typing import Iterable
import cv2
import numpy as np
from funcy import lmap, compose, rcompose, first, lkeep
from funcy import compose, rcompose, first, lkeep
from cv_analysis.utils.connect_rects import connect_related_rects2
from cv_analysis.utils.conversion import box_to_rectangle, rectangle_to_box
from cv_analysis.utils.postprocessing import (
remove_included,
has_no_parent,
)
from cv_analysis.utils.connect_rects import connect_related_rectangles
from cv_analysis.utils.conversion import box_to_rectangle
from cv_analysis.utils.postprocessing import remove_included, has_no_parent
from cv_analysis.utils.rectangle import Rectangle
@ -18,9 +15,7 @@ def parse_layout(image: np.array):
rectangles = find_segments(image)
rectangles = remove_included(rectangles)
boxes = lmap(rectangle_to_box, rectangles)
boxes = connect_related_rects2(boxes)
rectangles = lmap(box_to_rectangle, boxes)
rectangles = connect_related_rectangles(rectangles)
rectangles = remove_included(rectangles)
return rectangles
@ -60,13 +55,18 @@ def find_contours(image):
def is_likely_segment(rect, min_area=100):
    """Tell whether a contour is large enough to count as a segment.

    Args:
        rect: Input passed straight to ``cv2.contourArea``
            (presumably a contour as returned by ``cv2.findContours`` — TODO confirm).
        min_area: Area that must be strictly exceeded.

    Returns:
        True if the non-oriented contour area is greater than ``min_area``.
    """
    # FIXME: Parameterize via factory
    area = cv2.contourArea(rect, False)
    return area > min_area
def dilate_page_components(
    image,
    *,
    blur_kernel_size=(7, 7),
    dilation_kernel_size=(5, 5),
    dilation_iterations=4,
):
    """Blur, binarize and dilate an image.

    The image is smoothed with a Gaussian blur, binarized with an inverted
    Otsu threshold and finally dilated with a rectangular structuring element.

    Args:
        image: Input image (presumably 8-bit single-channel, as required by
            OpenCV's Otsu thresholding — TODO confirm against callers).
        blur_kernel_size: Kernel size ``(width, height)`` of the Gaussian blur.
        dilation_kernel_size: Size of the rectangular dilation kernel.
        dilation_iterations: Number of times the dilation is applied.

    Returns:
        The dilated binary image.
    """
    # NOTE: The former hard-coded values are kept as defaults; the original
    # FIXME asked for parameterization via a factory, which can now build on
    # these keyword arguments.
    # sigmaX=0 leaves the choice of sigma to OpenCV (derived from the kernel size).
    blurred = cv2.GaussianBlur(image, blur_kernel_size, 0)
    # With THRESH_OTSU the threshold argument (0) is not used as the cut-off;
    # index 1 of the result is the thresholded image.
    binarized = cv2.threshold(blurred, 0, 255, cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU)[1]
    kernel = cv2.getStructuringElement(cv2.MORPH_RECT, dilation_kernel_size)
    return cv2.dilate(binarized, kernel, iterations=dilation_iterations)
@ -87,6 +87,7 @@ def normalize_to_gray_scale(image):
def threshold_image(image, *, threshold=254, max_value=255):
    """Binarize an image with a fixed threshold.

    Args:
        image: Input image handed to ``cv2.threshold``.
        threshold: Pixels strictly above this value are set to ``max_value``,
            all others to 0 (``cv2.THRESH_BINARY``).
        max_value: Value assigned to pixels above the threshold.

    Returns:
        The thresholded image.
    """
    # NOTE: The former hard-coded values are kept as defaults; the original
    # FIXME asked for parameterization via a factory, which can now build on
    # these keyword arguments.
    # cv2.threshold returns (used_threshold, image); only the image is needed.
    _, binarized = cv2.threshold(image, threshold, max_value, cv2.THRESH_BINARY)
    return binarized

View File

@ -1,7 +1,9 @@
from itertools import combinations, starmap, product
from typing import Iterable
from cv_analysis.utils.conversion import rectangle_to_box
from funcy import lfilter, lmap
from cv_analysis.utils.conversion import rectangle_to_box, box_to_rectangle
from cv_analysis.utils.rectangle import Rectangle
@ -38,9 +40,9 @@ def is_on_same_line(rect_pair):
)
def has_correct_position1(rect_pair):
x1, y1, w1, h1 = rect_pair[0]
x2, y2, w2, h2 = rect_pair[1]
def has_correct_position(alpha: Rectangle, beta: Rectangle):
x1, y1, w1, h1 = alpha
x2, y2, w2, h2 = beta
return any(
[
any(
@ -59,7 +61,7 @@ def has_correct_position1(rect_pair):
def is_related(rect_pair):
    """Decide whether the two rectangles of a pair belong together.

    A pair is related if it is near enough and correctly positioned, or if
    the two rectangles overlap.

    Args:
        rect_pair: Pair of rectangles; unpacked as the two positional
            arguments of ``has_correct_position``.

    Returns:
        Truthy if the pair is considered related.
    """
    near_and_positioned = is_near_enough(rect_pair) and has_correct_position(*rect_pair)
    return near_and_positioned or is_overlapping(rect_pair)
def fuse_rects(rect1, rect2):
@ -76,61 +78,41 @@ def fuse_rects(rect1, rect2):
return tuple(topleft + w + h)
def rectangles_differ(r):
    """Return True if the first two elements of the pair ``r`` are unequal."""
    alpha, beta = r[0], r[1]
    return alpha != beta
def find_related_rects(rects):
    """Split rectangles into related pairs and rectangles without a partner.

    Args:
        rects: Collection of hashable rectangles.

    Returns:
        A tuple ``(pairs, unrelated)``: ``pairs`` lists every 2-combination of
        ``rects`` that is related and whose members differ; ``unrelated``
        lists the rectangles that occur in none of those pairs. If no pair is
        found, ``([], rects)`` is returned.
    """
    related_pairs = [
        pair
        for pair in combinations(rects, 2)
        if is_related(pair) and rectangles_differ(pair)
    ]
    if not related_pairs:
        return [], rects

    # Set for constant-time membership tests below.
    paired = {rect for pair in related_pairs for rect in pair}
    unpaired = [rect for rect in rects if rect not in paired]
    return related_pairs, unpaired
def connect_related_rectangles(rectangles: Iterable[Rectangle]):
    """Fuse related rectangles until no related pair is found anymore.

    The rectangles are converted to boxes, merged pairwise and converted back.
    In each round the box at the cursor is taken out and compared with all
    remaining boxes; it is fused with the first related one. The (possibly
    fused) box is put back at the front of the list. After a merge the scan
    restarts at the front, otherwise the cursor advances by one.

    Args:
        rectangles: Rectangles to connect.

    Returns:
        List of rectangles after merging.
    """
    boxes = lmap(rectangle_to_box, rectangles)
    cursor = 0

    # Stop once there is no box after the cursor position.
    while cursor + 1 < len(boxes):
        candidate = boxes.pop(cursor)

        for idx, other in enumerate(boxes):
            if is_related((candidate, other)):
                # Mutating `boxes` is safe here: iteration ends right after.
                candidate = fuse_rects(candidate, boxes.pop(idx))
                cursor = 0
                break
        else:
            # No partner found for this box; move on to the next position.
            cursor += 1

        boxes.insert(0, candidate)

    return lmap(box_to_rectangle, boxes)