diff --git a/cv_analysis/figure_detection/figure_detection.py b/cv_analysis/figure_detection/figure_detection.py index 750bcf4..4807930 100644 --- a/cv_analysis/figure_detection/figure_detection.py +++ b/cv_analysis/figure_detection/figure_detection.py @@ -1,11 +1,11 @@ from functools import partial -import cv2 import numpy as np +from funcy import lmap from cv_analysis.figure_detection.figures import detect_large_coherent_structures from cv_analysis.figure_detection.text import remove_primary_text_regions -from cv_analysis.utils.conversion import box_to_rectangle +from cv_analysis.utils.conversion import contour_to_rectangle from cv_analysis.utils.filters import ( is_large_enough, has_acceptable_format, @@ -24,8 +24,7 @@ def detect_figures(image: np.array): contours = detect_large_coherent_structures(image) contours = filter(figure_filter, contours) - boxes = map(cv2.boundingRect, contours) - rectangles = map(box_to_rectangle, boxes) + rectangles = lmap(contour_to_rectangle, contours) rectangles = remove_included(rectangles) return rectangles @@ -33,7 +32,7 @@ def detect_figures(image: np.array): def is_likely_figure(min_area, max_area, max_width_to_height_ratio, contours): return ( - is_small_enough(contours, max_area) - and is_large_enough(contours, min_area) - and has_acceptable_format(contours, max_width_to_height_ratio) + is_small_enough(contours, max_area) + and is_large_enough(contours, min_area) + and has_acceptable_format(contours, max_width_to_height_ratio) ) diff --git a/cv_analysis/figure_detection/figures.py b/cv_analysis/figure_detection/figures.py index 18a5d16..b4897a7 100644 --- a/cv_analysis/figure_detection/figures.py +++ b/cv_analysis/figure_detection/figures.py @@ -1,6 +1,8 @@ import cv2 import numpy as np +from cv_analysis.utils.common import find_contours + def detect_large_coherent_structures(image: np.array): """Detects large coherent structures on an image. @@ -20,6 +22,6 @@ def detect_large_coherent_structures(image: np.array): close_kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (20, 20)) close = cv2.morphologyEx(dilate, cv2.MORPH_CLOSE, close_kernel, iterations=1) - cnts, _ = cv2.findContours(close, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE) + contours, _ = find_contours(close) - return cnts + return contours diff --git a/cv_analysis/utils/common.py b/cv_analysis/utils/common.py index 1d71794..b7bc067 100644 --- a/cv_analysis/utils/common.py +++ b/cv_analysis/utils/common.py @@ -4,4 +4,4 @@ from funcy import first def find_contours(image): contours, hierarchies = cv2.findContours(image, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE) - return contours, first(hierarchies) + return contours, first(hierarchies) if hierarchies is not None else None diff --git a/cv_analysis/utils/conversion.py b/cv_analysis/utils/conversion.py index e03e3bc..3b0e100 100644 --- a/cv_analysis/utils/conversion.py +++ b/cv_analysis/utils/conversion.py @@ -1,8 +1,14 @@ import json +import cv2 + from cv_analysis.utils.rectangle import Rectangle +def contour_to_rectangle(contour): + return box_to_rectangle(cv2.boundingRect(contour)) + + def box_to_rectangle(box): x, y, w, h = box return Rectangle(x, y, x + w, y + h) diff --git a/cv_analysis/utils/display.py b/cv_analysis/utils/display.py index 0d3f2a6..ca6781f 100644 --- a/cv_analysis/utils/display.py +++ b/cv_analysis/utils/display.py @@ -1,4 +1,5 @@ import cv2 +from PIL import Image from matplotlib import pyplot as plt @@ -22,11 +23,15 @@ def show_image_mpl(image): plt.show() -def show_image(image, backend="m"): - if backend.startswith("m"): +def show_image(image, backend="mpl"): + if backend == "mpl": show_image_mpl(image) - else: + elif backend == "cv2": show_image_cv2(image) + elif backend == "pil": + Image.fromarray(image).show() + else: + raise ValueError(f"Unknown backend: {backend}") def save_image(image, path): diff --git a/cv_analysis/utils/drawing.py b/cv_analysis/utils/drawing.py index c07c7e6..9980a95 100644 --- a/cv_analysis/utils/drawing.py +++ b/cv_analysis/utils/drawing.py @@ -7,8 +7,8 @@ def draw_contours(image, contours): image = copy_and_normalize_channels(image) - for cont in contours: - cv2.drawContours(image, cont, -1, (0, 255, 0), 4) + for contour in contours: + cv2.drawContours(image, contour, -1, (0, 255, 0), 4) return image diff --git a/scripts/annotate.py b/scripts/annotate.py index 1056312..56b3597 100644 --- a/scripts/annotate.py +++ b/scripts/annotate.py @@ -69,4 +69,7 @@ def main(args): if __name__ == "__main__": - main(parse_args()) + try: + main(parse_args()) + except KeyboardInterrupt: + pass