Refactoring

Various
This commit is contained in:
Matthias Bisping 2023-01-09 11:21:42 +01:00
parent 06d6863cc5
commit 94e9210faf
8 changed files with 170 additions and 101 deletions

View File

@ -1,87 +1,100 @@
import itertools
from functools import reduce
from itertools import compress
from itertools import starmap
from operator import __and__
from typing import Iterable
import cv2
import numpy as np
from funcy import lmap, compose
from cv_analysis.utils.connect_rects import connect_related_rects2
from cv_analysis.utils.rectangle import Rectangle
from cv_analysis.utils.conversion import box_to_rectangle, rectangle_to_box
from cv_analysis.utils.postprocessing import (
remove_overlapping,
remove_included,
has_no_parent,
)
from cv_analysis.utils.visual_logging import vizlogger
# min_area could become a dynamic parameter if the scan is noisy
def is_likely_segment(rect, min_area=100):
    """Tell whether a contour is large enough to count as a segment.

    Args:
        rect: Contour handed to `cv2.contourArea` (called with its second argument set to False).
        min_area: Exclusive lower bound for the contour area.

    Returns:
        True when the contour area is strictly greater than `min_area`.
    """
    contour_area = cv2.contourArea(rect, False)
    return min_area < contour_area
from cv_analysis.utils.rectangle import Rectangle
def parse_layout(image: np.ndarray):
    """Detect layout segments on a page image.

    The page is normalized to gray scale and dilated, segments are searched on the result, and the
    detection is then re-run by `meta_detection` on a copy of the page. The resulting boxes are
    converted with `box_to_rectangle`, filtered with `remove_included`, converted back to boxes,
    merged by `connect_related_rects2`, and finally converted and filtered once more.

    Args:
        image: Page image. A copy is taken up front and that copy is what `meta_detection` receives.

    Returns:
        The result of the final `remove_included` call.
    """
    # Keep an untouched copy: the meta detection draws onto the image it is given.
    original = image.copy()

    image = normalize_to_gray_scale(image)
    image = dilate_page_components(image)
    rectangles = find_segments(image)

    rectangles = meta_detection(original, rectangles)

    rectangles = lmap(box_to_rectangle, rectangles)
    rectangles = remove_included(rectangles)

    # connect_related_rects2 works on (x, y, w, h) boxes, hence the round trip.
    rectangles = lmap(rectangle_to_box, rectangles)
    rectangles = connect_related_rects2(rectangles)

    rectangles = lmap(box_to_rectangle, rectangles)
    rectangles = remove_included(rectangles)

    return rectangles
def find_segments(image):
    """Find the bounding rectangles of likely segments in an image.

    External contours are extracted with `cv2.findContours`. A contour is kept when it passes
    `is_likely_segment` and its hierarchy entry passes `has_no_parent`.

    Args:
        image: Image in a form accepted by `cv2.findContours`.

    Returns:
        A list with one entry per kept contour: the contour's bounding box converted by
        `box_to_rectangle`. The list is empty when no contours are found.
    """
    contours, hierarchies = cv2.findContours(image, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

    # OpenCV reports the hierarchy as None when the image contains no contours at all;
    # indexing it would raise a TypeError.
    if hierarchies is None:
        return []

    is_large_enough = map(is_likely_segment, contours)
    is_top_level = map(has_no_parent, hierarchies[0])
    keep = map(__and__, is_large_enough, is_top_level)

    kept_contours = compress(contours, keep)
    return lmap(compose(box_to_rectangle, cv2.boundingRect), kept_contours)
def is_likely_segment(rect, min_area=100):
    """Report whether the area of a contour exceeds `min_area`.

    Args:
        rect: Contour handed to `cv2.contourArea` (called with its second argument set to False).
        min_area: Exclusive lower bound for the contour area.

    Returns:
        True when the contour area is strictly greater than `min_area`, False otherwise.
    """
    measured_area = cv2.contourArea(rect, False)
    exceeds_minimum = measured_area > min_area
    return exceeds_minimum
def dilate_page_components(image):
    """Blur a page image, binarize it (inverted, Otsu threshold) and dilate the result.

    Args:
        image: Page image; `parse_layout` passes the gray-scale version of the page.

    Returns:
        The binarized image after four dilation iterations with a 5x5 rectangular kernel.
    """
    # If text ends up being detected word by word, make the kernel bigger.
    blurred = cv2.GaussianBlur(image, (7, 7), 0)
    _, binarized = cv2.threshold(blurred, 0, 255, cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU)
    rect_kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (5, 5))
    dilated = cv2.dilate(binarized, rect_kernel, iterations=4)
    return dilated
def meta_detection(image: np.ndarray, rectangles: Iterable[Rectangle]):
    """Re-run the segment detection on an image in which earlier detections are filled in.

    Heuristically, this second pass improves the quality of the detection.

    Args:
        image: Image that the previously detected segments are drawn onto.
        rectangles: Previously detected segments.

    Returns:
        The segments found by `find_segments` on the filled, thresholded, inverted and
        gray-scale-normalized image.
    """
    filled = fill_rectangles(image, rectangles)
    inverted = invert_image(threshold_image(filled))
    gray = normalize_to_gray_scale(inverted)
    return find_segments(gray)
def normalize_to_gray_scale(image):
    """Return the image as gray scale.

    An image whose shape has at most two dimensions is returned unchanged; anything else is
    converted with `cv2.COLOR_BGR2GRAY`.
    """
    has_channel_axis = len(image.shape) > 2
    if not has_channel_axis:
        return image
    return cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
def threshold_image(image):
    """Binarize the image with `cv2.THRESH_BINARY`, threshold 254 and maximum value 255."""
    return cv2.threshold(image, 254, 255, cv2.THRESH_BINARY)[1]
def invert_image(image):
    """Return the bitwise complement (`~`) of the image."""
    inverted = ~image
    return inverted
def fill_rectangles(image, rectangles):
    """Apply `fill_in_component_area` once per rectangle, feeding each result into the next call.

    Args:
        image: Starting image.
        rectangles: Rectangles to process, in iteration order.

    Returns:
        The image returned by the last call, or `image` itself when there are no rectangles.
    """
    for rectangle in rectangles:
        image = fill_in_component_area(image, rectangle)
    return image
def fill_in_component_area(image, rect):
    """Black out a rectangle on the image and outline it in white.

    Thresholding and inversion are intentionally not done here: `fill_rectangles` calls this
    function once per rectangle and `meta_detection` applies `threshold_image` and `invert_image`
    once afterwards. Doing both here as well would invert the image again for every rectangle.

    Args:
        image: Image to draw on; `cv2.rectangle` draws directly onto it.
        rect: Iterable that unpacks to `(x, y, w, h)`.

    Returns:
        The image that was drawn on.
    """
    x, y, w, h = rect
    top_left, bottom_right = (x, y), (x + w, y + h)
    cv2.rectangle(image, top_left, bottom_right, (0, 0, 0), -1)  # negative thickness: filled
    cv2.rectangle(image, top_left, bottom_right, (255, 255, 255), 7)  # outline of thickness 7
    return image
# NOTE(review): second definition of parse_layout in this listing. It performs inline the steps
# that the earlier definition delegates to helpers (gray-scale conversion, dilation, segment
# search, fill / threshold / invert, second segment search, Rectangle conversion,
# remove_included and connect_related_rects2). Presumably this is the pre-refactoring body
# shown by the diff view; confirm against the actual file.
def parse_layout(image: np.array):
# Two copies: `image` is drawn onto below, `image_` is used for the first detection pass.
image = image.copy()
image_ = image.copy()
# Convert to gray scale only when the array has more than two dimensions.
if len(image_.shape) > 2:
image_ = cv2.cvtColor(image_, cv2.COLOR_BGR2GRAY)
dilate = dilate_page_components(image_)
# show_mpl(dilate)
rects = list(find_segments(dilate))
# -> Run meta detection on the previous detections TODO: refactor
for rect in rects:
x, y, w, h = rect
# Fill the detected area in black (thickness -1), then draw a white outline of thickness 7.
cv2.rectangle(image, (x, y), (x + w, y + h), (0, 0, 0), -1)
cv2.rectangle(image, (x, y), (x + w, y + h), (255, 255, 255), 7)
# show_mpl(image)
_, image = cv2.threshold(image, 254, 255, cv2.THRESH_BINARY)
image = ~image
# show_mpl(image)
if len(image.shape) > 2:
image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
rects = find_segments(image)
# <- End of meta detection
# Convert to Rectangle objects for remove_included, back to xywh for connect_related_rects2,
# then to Rectangle objects again for the final remove_included.
rects = list(map(Rectangle.from_xywh, rects))
rects = remove_included(rects)
rects = map(lambda r: r.xywh(), rects)
rects = connect_related_rects2(rects)
rects = list(map(Rectangle.from_xywh, rects))
rects = remove_included(rects)
return rects
# NOTE(review): as listed, this return directly follows `return rects` and cannot be reached;
# it possibly belongs to a different function in the diff view. Confirm.
return image

View File

@ -1,6 +1,9 @@
from itertools import combinations, starmap, product
from typing import Iterable
from cv_analysis.utils.conversion import rectangle_to_box
from cv_analysis.utils.rectangle import Rectangle
def is_near_enough(rect_pair, max_gap=14):
x1, y1, w1, h1 = rect_pair[0]
@ -117,6 +120,7 @@ def connect_related_rects2(rects: Iterable[tuple]):
break
merge_happened = False
current_rect = rects.pop(current_idx)
for idx, maybe_related_rect in enumerate(rects):
if is_related((current_rect, maybe_related_rect)):
current_rect = fuse_rects(current_rect, maybe_related_rect)

View File

@ -1,6 +1,28 @@
import json
from cv_analysis.utils.rectangle import Rectangle
def box_to_rectangle(box):
    """Convert an `(x, y, w, h)` box into a `Rectangle` built from its two corner points."""
    left, top, width, height = box
    right, bottom = left + width, top + height
    return Rectangle(left, top, right, bottom)
def rectangle_to_box(rectangle):
    """Convert a rectangle into the list `[x1, y1, width, height]`."""
    left, top = rectangle.x1, rectangle.y1
    box = [left, top, rectangle.width, rectangle.height]
    return box
class RectangleJSONEncoder(json.JSONEncoder):
    """JSON encoder that knows how to serialize `Rectangle` objects."""

    def __init__(self, *args, **kwargs):
        """Forward all arguments to `json.JSONEncoder` and create an empty `_replacement_map`."""
        super().__init__(*args, **kwargs)
        self._replacement_map = {}

    def default(self, o):
        """Encode a `Rectangle` as a dict of its `x1`, `x2`, `y1`, `y2`; defer everything else."""
        if not isinstance(o, Rectangle):
            return super().default(o)
        return {"x1": o.x1, "x2": o.x2, "y1": o.y1, "y2": o.y2}

    def encode(self, o):
        """Encode `o` exactly as the base encoder does."""
        return super().encode(o)

View File

@ -5,22 +5,24 @@ from PIL import Image
from cv_analysis.utils.preprocessing import preprocess_page_array
def open_pdf(pdf, first_page=0, last_page=None):
def open_analysis_input_file(path_or_bytes, first_page=1, last_page=None):
first_page += 1
last_page = None if last_page is None else last_page + 1
assert first_page > 0, "Page numbers are 1-based."
assert last_page is None or last_page >= first_page, "last_page must be greater than or equal to first_page."
if type(pdf) == str:
if pdf.lower().endswith((".png", ".jpg", ".jpeg")):
pages = [Image.open(pdf)]
elif pdf.lower().endswith(".pdf"):
pages = pdf2image.convert_from_path(pdf, first_page=first_page, last_page=last_page)
last_page = last_page or first_page
if type(path_or_bytes) == str:
if path_or_bytes.lower().endswith((".png", ".jpg", ".jpeg")):
pages = [Image.open(path_or_bytes)]
elif path_or_bytes.lower().endswith(".pdf"):
pages = pdf2image.convert_from_path(path_or_bytes, first_page=first_page, last_page=last_page)
else:
raise IOError("Invalid file extension. Accepted filetypes:\n\t.png\n\t.jpg\n\t.jpeg\n\t.pdf")
elif type(pdf) == bytes:
pages = pdf2image.convert_from_bytes(pdf, first_page=first_page, last_page=last_page)
elif type(pdf) in {list, ndarray}:
return pdf
raise IOError("Invalid file extension. Accepted filetypes: .png, .jpg, .jpeg, .pdf")
elif type(path_or_bytes) == bytes:
pages = pdf2image.convert_from_bytes(path_or_bytes, first_page=first_page, last_page=last_page)
elif type(path_or_bytes) in {list, ndarray}:
return path_or_bytes
pages = [preprocess_page_array(array(p)) for p in pages]

View File

@ -47,6 +47,12 @@ class Rectangle:
def __hash__(self):
    """Hash the rectangle by the tuple `(x1, y1, x2, y2)`."""
    corners = (self.x1, self.y1, self.x2, self.y2)
    return hash(corners)
def __iter__(self):
    """Yield `x1`, `y1`, `width` and `height`, in that order.

    This lets a rectangle be unpacked like an `(x, y, w, h)` box.
    """
    for attribute in ("x1", "y1", "width", "height"):
        yield getattr(self, attribute)
def area(self):
    """Return the area of this rectangle, as computed by the `area` function in scope."""
    result = area(self)
    return result

View File

@ -1,50 +1,72 @@
"""
Usage:
python scripts/annotate.py /home/iriley/Documents/pdf/scanned/10.pdf 5 --type table --show
python scripts/annotate.py /home/iriley/Documents/pdf/scanned/10.pdf 5 --type redaction --show
python scripts/annotate.py /home/iriley/Documents/pdf/scanned/10.pdf 5 --type layout --show
python scripts/annotate.py /home/iriley/Documents/pdf/scanned/10.pdf 5 --type figure --show
"""
import argparse
import loguru
from cv_analysis.figure_detection.figure_detection import detect_figures
from cv_analysis.layout_parsing import parse_layout
from cv_analysis.redaction_detection import find_redactions
from cv_analysis.table_parsing import parse_tables
from cv_analysis.utils.display import show_image
from cv_analysis.utils.draw import draw_contours, draw_rectangles
from cv_analysis.utils.open_pdf import open_pdf
from cv_analysis.utils.visual_logging import vizlogger
from cv_analysis.utils.open_pdf import open_analysis_input_file
def parse_args():
    """Parse the command line arguments of the annotation script.

    Each option is registered exactly once; registering `--type` twice makes argparse raise a
    conflicting-option error. `--last_page` is added as the regular double-dash spelling, and
    the earlier single-dash `-last_page` is kept as an alias.

    Returns:
        The `argparse.Namespace` with the attributes `pdf_path`, `page_index`, `show`,
        `first_page`, `last_page`, `type` and `page`.
    """
    parser = argparse.ArgumentParser(
        description="Annotate PDF pages with detected elements. Specified pages form a closed interval and are 1-based."
    )
    parser.add_argument("pdf_path")
    parser.add_argument("--page_index", type=int, default=0)
    parser.add_argument("--show", action="store_true", default=False)
    parser.add_argument(
        "--first_page",
        "-f",
        type=int,
        default=1,
    )
    parser.add_argument(
        "--last_page",
        "-last_page",  # previous spelling, kept so existing invocations keep working
        "-l",
        help="if not specified, defaults to the value of the first page specified",
        type=int,
        default=None,
    )
    parser.add_argument(
        "--type",
        "-t",
        help="element type to look for and analyze",
        choices=["table", "redaction", "layout", "figure"],
        default="table",
    )
    parser.add_argument("--page", "-p", type=int, default=1)

    return parser.parse_args()
def annotate_page(page_image, analysis_function, drawing_function, name="tmp.png", show=True):
    """Analyze a page, draw the result onto it and pass the annotated image to `vizlogger.debug`.

    Args:
        page_image: Page image handed to both the analysis and the drawing function.
        analysis_function: Callable producing the analysis result for the page image.
        drawing_function: Callable taking the page image and the result, returning the annotated image.
        name: Name passed to `vizlogger.debug` together with the annotated image.
        show: Accepted but not used in this function.
    """
    annotated = drawing_function(page_image, analysis_function(page_image))
    vizlogger.debug(annotated, name)
def annotate_page(page_image, analysis_fn, draw_fn):
    """Analyze a page, draw the result onto it and display the annotated image with `show_image`.

    Args:
        page_image: Page image handed to both the analysis and the drawing function.
        analysis_fn: Callable producing the analysis result for the page image.
        draw_fn: Callable taking the page image and the result, returning the annotated image.
    """
    annotated = draw_fn(page_image, analysis_fn(page_image))
    show_image(annotated)
# NOTE(review): two code paths appear interleaved in this span of the listing: a script section
# that selects `analyze` / `draw` with an if/elif chain on args.type, and the function
# get_analysis_and_draw_fn_for_type, which looks the same pair up in a dict. Confirm against the
# actual file which lines belong together before changing anything here.
if __name__ == "__main__":
args = parse_args()
# Opens only the page given by --page_index and takes the first returned element.
page = open_pdf(args.pdf_path, first_page=args.page_index, last_page=args.page_index)[0]
name = f"{args.type}_final_result.png"
# Rectangles are the default drawing function; see the draw_contours assignment further down.
draw = draw_rectangles
if args.type == "table":
from cv_analysis.table_parsing import parse_tables as analyze
elif args.type == "redaction":
from cv_analysis.redaction_detection import find_redactions as analyze
# Maps an element type name to its (analysis function, drawing function) pair; indexing the
# dict with an unknown element type raises KeyError.
def get_analysis_and_draw_fn_for_type(element_type):
analysis_fn, draw_fn = {
"table": (parse_tables, draw_rectangles),
"redaction": (find_redactions, draw_contours),
"layout": (parse_layout, draw_rectangles),
"figure": (detect_figures, draw_rectangles),
}[element_type]
# NOTE(review): presumably part of the "redaction" branch of the if/elif chain above - confirm.
draw = draw_contours
elif args.type == "layout":
from cv_analysis.layout_parsing import parse_layout as analyze
elif args.type == "figure":
from cv_analysis.figure_detection.figure_detection import detect_figures
analyze = detect_figures
annotate_page(page, analyze, draw, name=name, show=args.show)
# NOTE(review): presumably the return of get_analysis_and_draw_fn_for_type - confirm.
return analysis_fn, draw_fn
def main(args):
    """Annotate every requested page of the input file with the detected elements.

    Args:
        args: Parsed command line arguments; `type`, `pdf_path`, `first_page` and `last_page`
            are read.
    """
    loguru.logger.info(f"Annotating {args.type}s in {args.pdf_path}...")

    # The element type is fixed for the whole run, so resolve the functions once rather than
    # once per page.
    analysis_fn, draw_fn = get_analysis_and_draw_fn_for_type(args.type)

    pages = open_analysis_input_file(args.pdf_path, first_page=args.first_page, last_page=args.last_page)
    for page in pages:
        annotate_page(page, analysis_fn, draw_fn)
# Script entry point: parse the command line, then run the annotation.
if __name__ == "__main__":
    cli_args = parse_args()
    main(cli_args)

View File

@ -10,7 +10,7 @@ from loguru import logger
from cv_analysis.config import get_config
from cv_analysis.locations import REPO_ROOT_PATH, TEST_DATA_DVC
from cv_analysis.utils.draw import draw_rectangles
from cv_analysis.utils.open_pdf import open_pdf
from cv_analysis.utils.open_pdf import open_analysis_input_file
from test.fixtures.figure_detection import paste_text
CV_CONFIG = get_config()
@ -19,7 +19,7 @@ CV_CONFIG = get_config()
# Fixture returning the first element produced by opening the image `test{test_file_index}.png`
# from the configured test data directory.
@pytest.fixture
def client_page_with_table(test_file_index, dvc_test_data):
img_path = join(CV_CONFIG.test_data_dir, f"test{test_file_index}.png")
# NOTE(review): two consecutive returns as listed, so only the first is reachable. Presumably
# the diff view shows the old open_pdf call followed by its open_analysis_input_file
# replacement; confirm against the actual file.
return first(open_pdf(img_path))
return first(open_analysis_input_file(img_path))
@pytest.fixture(scope="session")

View File