diff --git a/cv_analysis/layout_parsing.py b/cv_analysis/layout_parsing.py index 6bb9d68..1a9718b 100644 --- a/cv_analysis/layout_parsing.py +++ b/cv_analysis/layout_parsing.py @@ -3,14 +3,11 @@ from typing import Iterable import cv2 import numpy as np -from funcy import lmap, compose, rcompose, first, lkeep +from funcy import compose, rcompose, first, lkeep -from cv_analysis.utils.connect_rects import connect_related_rects2 -from cv_analysis.utils.conversion import box_to_rectangle, rectangle_to_box -from cv_analysis.utils.postprocessing import ( - remove_included, - has_no_parent, -) +from cv_analysis.utils.connect_rects import connect_related_rectangles +from cv_analysis.utils.conversion import box_to_rectangle +from cv_analysis.utils.postprocessing import remove_included, has_no_parent from cv_analysis.utils.rectangle import Rectangle @@ -18,9 +15,7 @@ def parse_layout(image: np.array): rectangles = find_segments(image) rectangles = remove_included(rectangles) - boxes = lmap(rectangle_to_box, rectangles) - boxes = connect_related_rects2(boxes) - rectangles = lmap(box_to_rectangle, boxes) + rectangles = connect_related_rectangles(rectangles) rectangles = remove_included(rectangles) return rectangles @@ -60,13 +55,18 @@ def find_contours(image): def is_likely_segment(rect, min_area=100): + # FIXME: Parameterize via factory return cv2.contourArea(rect, False) > min_area def dilate_page_components(image): + # FIXME: Parameterize via factory image = cv2.GaussianBlur(image, (7, 7), 0) + # FIXME: Parameterize via factory thresh = cv2.threshold(image, 0, 255, cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU)[1] + # FIXME: Parameterize via factory kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (5, 5)) + # FIXME: Parameterize via factory return cv2.dilate(thresh, kernel, iterations=4) @@ -87,6 +87,7 @@ def normalize_to_gray_scale(image): def threshold_image(image): + # FIXME: Parameterize via factory _, image = cv2.threshold(image, 254, 255, cv2.THRESH_BINARY) return image diff --git a/cv_analysis/utils/connect_rects.py b/cv_analysis/utils/connect_rects.py index 3c67777..abc1fd2 100644 --- a/cv_analysis/utils/connect_rects.py +++ b/cv_analysis/utils/connect_rects.py @@ -1,7 +1,9 @@ from itertools import combinations, starmap, product from typing import Iterable -from cv_analysis.utils.conversion import rectangle_to_box +from funcy import lfilter, lmap + +from cv_analysis.utils.conversion import rectangle_to_box, box_to_rectangle from cv_analysis.utils.rectangle import Rectangle @@ -38,9 +40,9 @@ def is_on_same_line(rect_pair): ) -def has_correct_position1(rect_pair): - x1, y1, w1, h1 = rect_pair[0] - x2, y2, w2, h2 = rect_pair[1] +def has_correct_position(alpha: Rectangle, beta: Rectangle): + x1, y1, w1, h1 = alpha + x2, y2, w2, h2 = beta return any( [ any( @@ -59,7 +61,7 @@ def has_correct_position1(rect_pair): def is_related(rect_pair): - return (is_near_enough(rect_pair) and has_correct_position1(rect_pair)) or is_overlapping(rect_pair) + return (is_near_enough(rect_pair) and has_correct_position(*rect_pair)) or is_overlapping(rect_pair) def fuse_rects(rect1, rect2): @@ -76,61 +78,41 @@ def fuse_rects(rect1, rect2): return tuple(topleft + w + h) -def rects_not_the_same(r): +def rectangles_differ(r): return r[0] != r[1] def find_related_rects(rects): - rect_pairs = list(filter(is_related, combinations(rects, 2))) - rect_pairs = list(filter(rects_not_the_same, rect_pairs)) + rect_pairs = lfilter(is_related, combinations(rects, 2)) + rect_pairs = lfilter(rectangles_differ, rect_pairs) if not rect_pairs: return [], rects - rel_rects = list(set([rect for pair in rect_pairs for rect in pair])) + rel_rects = set([rect for pair in rect_pairs for rect in pair]) unrel_rects = [rect for rect in rects if rect not in rel_rects] return rect_pairs, unrel_rects -def connect_related_rects(rects): - rects_to_connect, rects_new = find_related_rects(rects) +def connect_related_rectangles(rectangles: Iterable[Rectangle]): + boxes = lmap(rectangle_to_box, rectangles) - while len(rects_to_connect) > 0: - rects_fused = list(starmap(fuse_rects, rects_to_connect)) - rects_fused = list(dict.fromkeys(rects_fused)) - - if len(rects_fused) == 1: - rects_new += rects_fused - rects_fused = [] - - rects_to_connect, connected_rects = find_related_rects(rects_fused) - rects_new += connected_rects - - if len(rects_to_connect) > 1 and len(set(rects_to_connect)) == 1: - rects_new.append(rects_fused[0]) - rects_to_connect = [] - - return rects_new - - -def connect_related_rects2(rects: Iterable[tuple]): - rects = list(rects) current_idx = 0 while True: - if current_idx + 1 >= len(rects) or len(rects) <= 1: + if current_idx + 1 >= len(boxes) or len(boxes) <= 1: break merge_happened = False - current_rect = rects.pop(current_idx) + current_rect = boxes.pop(current_idx) - for idx, maybe_related_rect in enumerate(rects): + for idx, maybe_related_rect in enumerate(boxes): if is_related((current_rect, maybe_related_rect)): current_rect = fuse_rects(current_rect, maybe_related_rect) - rects.pop(idx) + boxes.pop(idx) merge_happened = True break - rects.insert(0, current_rect) + boxes.insert(0, current_rect) if not merge_happened: current_idx += 1 elif merge_happened: current_idx = 0 - return rects + return lmap(box_to_rectangle, boxes)