From ce9e92876c657f7aff05aee24c0e92b2a29db6da Mon Sep 17 00:00:00 2001 From: Julius Unverfehrt Date: Mon, 11 Jul 2022 12:25:16 +0200 Subject: [PATCH] Pull request #16: Add table parsing fixtures Merge in RR/cv-analysis from add_table_parsing_fixtures to master Squashed commit of the following: commit cfc89b421b61082c8e92e1971c9d0bf4490fa07e Merge: a7ecb05 73c66a8 Author: Julius Unverfehrt Date: Mon Jul 11 12:19:01 2022 +0200 Merge branch 'master' of ssh://git.iqser.com:2222/rr/cv-analysis into add_table_parsing_fixtures commit a7ecb05b7d8327f0c7429180f63a380b61b06bc3 Author: Julius Unverfehrt Date: Mon Jul 11 12:02:07 2022 +0200 refactor commit 466f217e5a9ee5c54fd38c6acd28d54fc38ff9bb Author: llocarnini Date: Mon Jul 11 10:24:14 2022 +0200 deleted unused imports and unused lines of code commit c58955c8658d0631cdd1c24c8556d399e3fd9990 Author: llocarnini Date: Mon Jul 11 10:16:01 2022 +0200 black reformatted files commit f8bcb10a00ff7f0da49b80c1609b17997411985a Author: llocarnini Date: Tue Jul 5 15:15:00 2022 +0200 reformat files commit 432e8a569fd70bd0745ce0549c2bfd2f2e907763 Author: llocarnini Date: Tue Jul 5 15:08:22 2022 +0200 added better test for generic pages with table WIP as thicker lines create inconsistent results. added test for patchy tables which does not work yet commit 2aac9ebf5c76bd963f8c136fe5dd4c2d7681b469 Author: llocarnini Date: Mon Jul 4 16:56:29 2022 +0200 added new fixtures for table_parsing_test.py commit 37606cac0301b13e99be2c16d95867477f29e7c4 Author: llocarnini Date: Fri Jul 1 16:02:44 2022 +0200 added separate file for table parsing fixtures, where fixtures for generic tables were added. WIP tests for generic table fixtures --- .gitignore | 1 + .../figure_detection_pipeline.py | 10 +- cv_analysis/layout_parsing.py | 10 +- cv_analysis/redaction_detection.py | 14 +- cv_analysis/table_parsing.py | 89 +++---- cv_analysis/utils/deskew.py | 12 +- cv_analysis/utils/draw.py | 10 +- cv_analysis/utils/logging.py | 4 +- cv_analysis/utils/post_processing.py | 18 +- cv_analysis/utils/preprocessing.py | 8 +- cv_analysis/utils/structures.py | 10 +- cv_analysis/utils/test_metrics.py | 6 +- cv_analysis/utils/visual_logging.py | 4 +- data/.gitignore | 2 - data/pdfs_for_testing.dvc | 5 - data/pngs_for_testing.dvc | 5 - test/conftest.py | 1 + test/fixtures/figure_detection.py | 19 +- test/fixtures/server.py | 1 + test/fixtures/table_parsing.py | 252 ++++++++++++++++++ .../figure_detection_pipeline_test.py | 9 +- test/unit_tests/figure_detection/text_test.py | 31 ++- .../server/formatted_stream_fn_test.py | 4 +- test/unit_tests/table_parsing_test.py | 59 ++-- 24 files changed, 469 insertions(+), 115 deletions(-) delete mode 100644 data/pdfs_for_testing.dvc delete mode 100644 data/pngs_for_testing.dvc create mode 100644 test/fixtures/table_parsing.py diff --git a/.gitignore b/.gitignore index 6f60597..f4ebc2a 100644 --- a/.gitignore +++ b/.gitignore @@ -24,3 +24,4 @@ build_venv/ /cv_analysis/test/test_data/example_pages.json /data/metadata_testing_files.csv .coverage +/data/ diff --git a/cv_analysis/figure_detection/figure_detection_pipeline.py b/cv_analysis/figure_detection/figure_detection_pipeline.py index 9a98be7..f0a3b35 100644 --- a/cv_analysis/figure_detection/figure_detection_pipeline.py +++ b/cv_analysis/figure_detection/figure_detection_pipeline.py @@ -5,7 +5,11 @@ import numpy as np from cv_analysis.figure_detection.figures import detect_large_coherent_structures from cv_analysis.figure_detection.text import remove_primary_text_regions -from cv_analysis.utils.filters import is_large_enough, has_acceptable_format, is_not_too_large +from cv_analysis.utils.filters import ( + is_large_enough, + has_acceptable_format, + is_not_too_large, +) from cv_analysis.utils.post_processing import remove_included from cv_analysis.utils.structures import Rectangle @@ -13,7 +17,9 @@ from cv_analysis.utils.structures import Rectangle def make_figure_detection_pipeline(min_area=5000, max_width_to_height_ratio=6): def pipeline(image: np.array): max_area = image.shape[0] * image.shape[1] * 0.99 - filter_cnts = make_filter_likely_figures(min_area, max_area, max_width_to_height_ratio) + filter_cnts = make_filter_likely_figures( + min_area, max_area, max_width_to_height_ratio + ) image = remove_primary_text_regions(image) cnts = detect_large_coherent_structures(image) diff --git a/cv_analysis/layout_parsing.py b/cv_analysis/layout_parsing.py index 092c22a..3ffeecf 100644 --- a/cv_analysis/layout_parsing.py +++ b/cv_analysis/layout_parsing.py @@ -10,7 +10,11 @@ import numpy as np # from cv_analysis.utils.display import show_mpl # from cv_analysis.utils.draw import draw_rectangles from cv_analysis.utils.structures import Rectangle -from cv_analysis.utils.post_processing import remove_overlapping, remove_included, has_no_parent +from cv_analysis.utils.post_processing import ( + remove_overlapping, + remove_included, + has_no_parent, +) from cv_analysis.utils.visual_logging import vizlogger @@ -19,7 +23,9 @@ def is_likely_segment(rect, min_area=100): def find_segments(image): - contours, hierarchies = cv2.findContours(image, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE) + contours, hierarchies = cv2.findContours( + image, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE + ) mask1 = map(is_likely_segment, contours) mask2 = map(has_no_parent, hierarchies[0]) diff --git a/cv_analysis/redaction_detection.py b/cv_analysis/redaction_detection.py index a633bee..3c5bf5f 100644 --- a/cv_analysis/redaction_detection.py +++ b/cv_analysis/redaction_detection.py @@ -12,7 +12,9 @@ from cv_analysis.utils.visual_logging import vizlogger def is_likely_redaction(contour, hierarchy, min_area): - return is_filled(hierarchy) and is_boxy(contour) and is_large_enough(contour, min_area) + return ( + is_filled(hierarchy) and is_boxy(contour) and is_large_enough(contour, min_area) + ) def find_redactions(image: np.array, min_normalized_area=200000): @@ -29,11 +31,17 @@ def find_redactions(image: np.array, min_normalized_area=200000): thresh = cv2.threshold(blurred, 252, 255, cv2.THRESH_BINARY)[1] vizlogger.debug(blurred, "redactions04_threshold.png") - contours, hierarchies = cv2.findContours(thresh, cv2.RETR_TREE, cv2.CHAIN_APPROX_NONE) + contours, hierarchies = cv2.findContours( + thresh, cv2.RETR_TREE, cv2.CHAIN_APPROX_NONE + ) try: contours = map( - first, starfilter(partial(is_likely_redaction, min_area=min_normalized_area), zip(contours, hierarchies[0])) + first, + starfilter( + partial(is_likely_redaction, min_area=min_normalized_area), + zip(contours, hierarchies[0]), + ), ) return list(contours) except: diff --git a/cv_analysis/table_parsing.py b/cv_analysis/table_parsing.py index 78ede82..52d5292 100644 --- a/cv_analysis/table_parsing.py +++ b/cv_analysis/table_parsing.py @@ -4,37 +4,27 @@ from operator import attrgetter import cv2 import numpy as np -# from pdf2image import pdf2image -# from cv_analysis.utils.display import show_mpl -# from cv_analysis.utils.draw import draw_rectangles +from funcy import lmap + from cv_analysis.utils.post_processing import xywh_to_vecs, xywh_to_vec_rect, adjacent1d -# from cv_analysis.utils.deskew import deskew_histbased, deskew -# from cv_analysis.utils.filters import is_large_enough from cv_analysis.utils.structures import Rectangle from cv_analysis.utils.visual_logging import vizlogger from cv_analysis.layout_parsing import parse_layout -def add_external_contours(image, img): - contours, _ = cv2.findContours(img, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE) +def add_external_contours(image, image_h_w_lines_only): + + contours, _ = cv2.findContours( + image_h_w_lines_only, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE + ) for cnt in contours: x, y, w, h = cv2.boundingRect(cnt) cv2.rectangle(image, (x, y), (x + w, y + h), 255, 1) - vizlogger.debug(image, "external_contours.png") + return image -def extend_lines(): - # TODO - pass - - -def make_table_block_mask(): - # TODO - pass - - def apply_motion_blur(image: np.array, angle, size=80): """Solidifies and slightly extends detected lines. @@ -51,7 +41,11 @@ def apply_motion_blur(image: np.array, angle, size=80): vizlogger.debug(k, "tables08_blur_kernel1.png") k[(size - 1) // 2, :] = np.ones(size, dtype=np.float32) vizlogger.debug(k, "tables09_blur_kernel2.png") - k = cv2.warpAffine(k, cv2.getRotationMatrix2D((size / 2 - 0.5, size / 2 - 0.5), angle, 1.0), (size, size)) + k = cv2.warpAffine( + k, + cv2.getRotationMatrix2D((size / 2 - 0.5, size / 2 - 0.5), angle, 1.0), + (size, size), + ) vizlogger.debug(k, "tables10_blur_kernel3.png") k = k * (1.0 / np.sum(k)) vizlogger.debug(k, "tables11_blur_kernel4.png") @@ -74,34 +68,25 @@ def isolate_vertical_and_horizontal_components(img_bin): kernel_v = np.ones((line_min_width, 1), np.uint8) img_bin_h = cv2.morphologyEx(img_bin, cv2.MORPH_OPEN, kernel_h) - vizlogger.debug(img_bin_h, "tables01_isolate01_img_bin_h.png") img_bin_v = cv2.morphologyEx(img_bin, cv2.MORPH_OPEN, kernel_v) img_lines_raw = img_bin_v | img_bin_h - vizlogger.debug(img_lines_raw, "tables02_isolate02_img_bin_v.png") kernel_h = np.ones((1, 30), np.uint8) kernel_v = np.ones((30, 1), np.uint8) img_bin_h = cv2.dilate(img_bin_h, kernel_h, iterations=2) - vizlogger.debug(img_bin_h, "tables03_isolate03_dilate_h.png") img_bin_v = cv2.dilate(img_bin_v, kernel_v, iterations=2) - vizlogger.debug(img_bin_v | img_bin_h, "tables04_isolate04_dilate_v.png") img_bin_h = apply_motion_blur(img_bin_h, 0) - vizlogger.debug(img_bin_h, "tables09_isolate05_blur_h.png") img_bin_v = apply_motion_blur(img_bin_v, 90) - vizlogger.debug(img_bin_v | img_bin_h, "tables10_isolate06_blur_v.png") - img_bin_final = img_bin_h | img_bin_v - vizlogger.debug(img_bin_final, "tables11_isolate07_final.png") - - th1, img_bin_final = cv2.threshold(img_bin_final, 120, 255, cv2.THRESH_BINARY) - vizlogger.debug(img_bin_final, "tables10_isolate12_threshold.png") - img_bin_final = cv2.dilate(img_bin_final, np.ones((1, 1), np.uint8), iterations=1) - vizlogger.debug(img_bin_final, "tables11_isolate13_dilate.png") + img_bin_extended = img_bin_h | img_bin_v + th1, img_bin_extended = cv2.threshold(img_bin_extended, 120, 255, cv2.THRESH_BINARY) + img_bin_final = cv2.dilate( + img_bin_extended, np.ones((1, 1), np.uint8), iterations=1 + ) # add contours before lines are extended by blurring img_bin_final = add_external_contours(img_bin_final, img_lines_raw) - vizlogger.debug(img_bin_final, "tables11_isolate14_contours_added.png") return img_bin_final @@ -130,13 +115,15 @@ def has_table_shape(rects): def find_table_layout_boxes(image: np.array): - layout_boxes = parse_layout(image) - table_boxes = [] - for box in layout_boxes: + def is_large_enough(box): (x, y, w, h) = box if w * h >= 100000: - table_boxes.append(Rectangle.from_xywh(box)) - return table_boxes + return Rectangle.from_xywh(box) + + layout_boxes = parse_layout(image) + a = lmap(is_large_enough, layout_boxes) + print(a) + return lmap(is_large_enough, layout_boxes) def preprocess(image: np.array): @@ -145,6 +132,19 @@ def preprocess(image: np.array): return ~image +def turn_connected_components_into_rects(image): + def is_large_enough(stat): + x1, y1, w, h, area = stat + return area > 2000 and w > 35 and h > 25 + + _, _, stats, _ = cv2.connectedComponentsWithStats( + ~image, connectivity=8, ltype=cv2.CV_32S + ) + + stats = np.vstack(list(filter(is_large_enough, stats))) + return stats[:, :-1][2:] + + def parse_tables(image: np.array, show=False): """Runs the full table parsing process. @@ -155,21 +155,10 @@ def parse_tables(image: np.array, show=False): list: list of rectangles corresponding to table cells """ - def is_large_enough(stat): - x1, y1, w, h, area = stat - return area > 2000 and w > 35 and h > 25 - image = preprocess(image) - # table_layout_boxes = find_table_layout_boxes(image) - image = isolate_vertical_and_horizontal_components(image) - # image = add_external_contours(image, image) - vizlogger.debug(image, "external_contours_added.png") - _, _, stats, _ = cv2.connectedComponentsWithStats(~image, connectivity=8, ltype=cv2.CV_32S) - - stats = np.vstack(list(filter(is_large_enough, stats))) - rects = stats[:, :-1][2:] + rects = turn_connected_components_into_rects(image) return list(map(Rectangle.from_xywh, rects)) diff --git a/cv_analysis/utils/deskew.py b/cv_analysis/utils/deskew.py index 06ec870..98f3de3 100644 --- a/cv_analysis/utils/deskew.py +++ b/cv_analysis/utils/deskew.py @@ -9,7 +9,9 @@ def rotate_straight(im: np.array, skew_angle: int) -> np.array: h, w = im.shape[:2] center = (w // 2, h // 2) M = cv2.getRotationMatrix2D(center, skew_angle, 1.0) - rotated = cv2.warpAffine(im, M, (w, h), flags=cv2.INTER_CUBIC, borderMode=cv2.BORDER_REPLICATE) + rotated = cv2.warpAffine( + im, M, (w, h), flags=cv2.INTER_CUBIC, borderMode=cv2.BORDER_REPLICATE + ) return rotated @@ -68,7 +70,9 @@ def needs_deskew(page: np.array) -> bool: unrotated_score = np.mean(np.abs(split_rowmean_diff(page))) angles = [-CONFIG.deskew.test_delta, CONFIG.deskew.test_delta] - scores = [np.mean(np.abs(split_rowmean_diff(rotate(page, angle)))) for angle in angles] + scores = [ + np.mean(np.abs(split_rowmean_diff(rotate(page, angle)))) for angle in angles + ] print(unrotated_score, scores) return unrotated_score > min(scores) @@ -78,4 +82,6 @@ if CONFIG.deskew.function == "hist": elif CONFIG.deskew.function == "identity": deskew = lambda page: (page, None) else: - raise ValueError("'{CONFIG.deskew.function}' is not a valid parameter value for CONFIG.deskew.function") + raise ValueError( + "'{CONFIG.deskew.function}' is not a valid parameter value for CONFIG.deskew.function" + ) diff --git a/cv_analysis/utils/draw.py b/cv_analysis/utils/draw.py index 96d0b3f..411eebd 100644 --- a/cv_analysis/utils/draw.py +++ b/cv_analysis/utils/draw.py @@ -15,7 +15,15 @@ def draw_contours(image, contours, color=None, annotate=False): def draw_rectangles(image, rectangles, color=None, annotate=False): def annotate_rect(x, y, w, h): - cv2.putText(image, "+", (x + (w // 2) - 12, y + (h // 2) + 9), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 2) + cv2.putText( + image, + "+", + (x + (w // 2) - 12, y + (h // 2) + 9), + cv2.FONT_HERSHEY_SIMPLEX, + 1, + (0, 255, 0), + 2, + ) image = copy_and_normalize_channels(image) diff --git a/cv_analysis/utils/logging.py b/cv_analysis/utils/logging.py index 51be0fb..6fc280f 100644 --- a/cv_analysis/utils/logging.py +++ b/cv_analysis/utils/logging.py @@ -8,7 +8,9 @@ from cv_analysis.config import CONFIG def make_logger_getter(): logger = logging.getLogger(__name__) logger.setLevel(logging.getLevelName(CONFIG.service.logging_level)) - formatter = logging.Formatter(fmt="%(asctime)s %(levelname)s: %(message)s", datefmt="%d.%m.%Y - %H:%M:%S") + formatter = logging.Formatter( + fmt="%(asctime)s %(levelname)s: %(message)s", datefmt="%d.%m.%Y - %H:%M:%S" + ) ch = logging.StreamHandler(sys.stdout) ch.setLevel(logging.getLevelName(CONFIG.service.logging_level)) diff --git a/cv_analysis/utils/post_processing.py b/cv_analysis/utils/post_processing.py index 1749f2d..46da1dc 100644 --- a/cv_analysis/utils/post_processing.py +++ b/cv_analysis/utils/post_processing.py @@ -18,11 +18,21 @@ def remove_overlapping(rectangles): def remove_included(rectangles): def included(a, b): - return b.xmin >= a.xmin and b.ymin >= a.ymin and b.xmax <= a.xmax and b.ymax <= a.ymax + return ( + b.xmin >= a.xmin + and b.ymin >= a.ymin + and b.xmax <= a.xmax + and b.ymax <= a.ymax + ) def includes(a, b, tol=3): """does a include b?""" - return b.xmin + tol >= a.xmin and b.ymin + tol >= a.ymin and b.xmax - tol <= a.xmax and b.ymax - tol <= a.ymax + return ( + b.xmin + tol >= a.xmin + and b.ymin + tol >= a.ymin + and b.xmax - tol <= a.xmax + and b.ymax - tol <= a.ymax + ) def is_not_included(rect, rectangles): return not any(includes(r2, rect) for r2 in rectangles if not rect == r2) @@ -100,7 +110,9 @@ def __remove_isolated_sorted(rectangles): def remove_isolated(rectangles, input_sorted=False): - return (__remove_isolated_sorted if input_sorted else __remove_isolated_unsorted)(rectangles) + return (__remove_isolated_sorted if input_sorted else __remove_isolated_unsorted)( + rectangles + ) Rectangle = namedtuple("Rectangle", "xmin ymin xmax ymax") diff --git a/cv_analysis/utils/preprocessing.py b/cv_analysis/utils/preprocessing.py index 011d54e..d51139a 100644 --- a/cv_analysis/utils/preprocessing.py +++ b/cv_analysis/utils/preprocessing.py @@ -21,9 +21,13 @@ def open_pdf(pdf, first_page=0, last_page=None): if pdf.lower().endswith((".png", ".jpg", ".jpeg")): pages = [Image.open(pdf)] else: # assume pdf as default file type for a path argument - pages = pdf2image.convert_from_path(pdf, first_page=first_page, last_page=last_page) + pages = pdf2image.convert_from_path( + pdf, first_page=first_page, last_page=last_page + ) elif type(pdf) == bytes: - pages = pdf2image.convert_from_bytes(pdf, first_page=first_page, last_page=last_page) + pages = pdf2image.convert_from_bytes( + pdf, first_page=first_page, last_page=last_page + ) elif type(pdf) in {list, ndarray}: return pdf diff --git a/cv_analysis/utils/structures.py b/cv_analysis/utils/structures.py index adec723..076e1a3 100644 --- a/cv_analysis/utils/structures.py +++ b/cv_analysis/utils/structures.py @@ -7,6 +7,7 @@ from funcy import identity class Rectangle: def __init__(self, x1=None, y1=None, w=None, h=None, x2=None, y2=None, indent=4, format="xywh", discrete=True): make_discrete = int if discrete else identity + try: self.x1 = make_discrete(x1) self.y1 = make_discrete(y1) @@ -28,7 +29,14 @@ class Rectangle: return {"x1": self.x1, "y1": self.y1, "x2": self.x2, "y2": self.y2} def json_full(self): - return {"x1": self.x1, "y1": self.y1, "x2": self.x2, "y2": self.y2, "width": self.w, "height": self.h} + return { + "x1": self.x1, + "y1": self.y1, + "x2": self.x2, + "y2": self.y2, + "width": self.w, + "height": self.h, + } def json(self): json_func = {"xywh": self.json_xywh, "xyxy": self.json_xyxy}.get(self.format, self.json_full) diff --git a/cv_analysis/utils/test_metrics.py b/cv_analysis/utils/test_metrics.py index 8df3d00..fd0eca6 100644 --- a/cv_analysis/utils/test_metrics.py +++ b/cv_analysis/utils/test_metrics.py @@ -75,7 +75,11 @@ def compute_document_score(results_dict, annotation_dict): scores = [] for i in range(len(annotation_dict["pages"])): - scores.append(compute_page_iou(results_dict["pages"][i]["cells"], annotation_dict["pages"][i]["cells"])) + scores.append( + compute_page_iou( + results_dict["pages"][i]["cells"], annotation_dict["pages"][i]["cells"] + ) + ) scores = np.array(scores) doc_score = np.average(scores, weights=page_weights) diff --git a/cv_analysis/utils/visual_logging.py b/cv_analysis/utils/visual_logging.py index 1892805..983b546 100644 --- a/cv_analysis/utils/visual_logging.py +++ b/cv_analysis/utils/visual_logging.py @@ -36,4 +36,6 @@ class VisualLogger: return self.level == "ALL" -vizlogger = VisualLogger(CONFIG.visual_logging.level, CONFIG.visual_logging.output_folder) +vizlogger = VisualLogger( + CONFIG.visual_logging.level, CONFIG.visual_logging.output_folder +) diff --git a/data/.gitignore b/data/.gitignore index 7b38b1e..8f6ae38 100644 --- a/data/.gitignore +++ b/data/.gitignore @@ -1,7 +1,5 @@ /test_pdf.pdf -/pdfs_for_testing /figure_detection.png /layout_parsing.png /redaction_detection.png /table_parsing.png -/pngs_for_testing diff --git a/data/pdfs_for_testing.dvc b/data/pdfs_for_testing.dvc deleted file mode 100644 index e85e518..0000000 --- a/data/pdfs_for_testing.dvc +++ /dev/null @@ -1,5 +0,0 @@ -outs: -- md5: bb0ce084f7ca54583972da71cb87e22c.dir - size: 367181628 - nfiles: 28 - path: pdfs_for_testing diff --git a/data/pngs_for_testing.dvc b/data/pngs_for_testing.dvc deleted file mode 100644 index 630eab7..0000000 --- a/data/pngs_for_testing.dvc +++ /dev/null @@ -1,5 +0,0 @@ -outs: -- md5: 4fed91116111b47edf1c6f6a67eb84d3.dir - size: 58125058 - nfiles: 230 - path: pngs_for_testing diff --git a/test/conftest.py b/test/conftest.py index 6760193..8ccc497 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -1,4 +1,5 @@ pytest_plugins = [ + "test.fixtures.table_parsing", "test.fixtures.server", "test.fixtures.figure_detection", ] diff --git a/test/fixtures/figure_detection.py b/test/fixtures/figure_detection.py index 83aefa4..386e4a7 100644 --- a/test/fixtures/figure_detection.py +++ b/test/fixtures/figure_detection.py @@ -6,7 +6,9 @@ import pytest from PIL import Image from lorem_text import lorem from funcy import first -from cv_analysis.figure_detection.figure_detection_pipeline import make_figure_detection_pipeline +from cv_analysis.figure_detection.figure_detection_pipeline import ( + make_figure_detection_pipeline, +) from cv_analysis.utils.display import show_mpl @@ -29,11 +31,15 @@ def page_with_text(background, font_scale, font_style, text_types): if "body" in text_types: cursor = (image.shape[1] // 2, 70) image = paste_text(image, cursor, font_scale, font_style, y_stop=body_height) - cursor = (50, body_height+70) - image = paste_text(image, cursor, font_scale, font_style, y_stop=body_height*2) + cursor = (50, body_height + 70) + image = paste_text( + image, cursor, font_scale, font_style, y_stop=body_height * 2 + ) if "caption" in text_types: cursor = (image.shape[1] // 2, image.shape[0] - 100) - image = paste_text(image, cursor, font_scale, font_style, y_stop=body_height*3) + image = paste_text( + image, cursor, font_scale, font_style, y_stop=body_height * 3 + ) return image @@ -61,7 +67,9 @@ def paste_text(image: np.ndarray, cursor, font_scale, font_style, y_stop): def paste_text_at_cursor(x_start, y_start, y_stop): # TODO: adjust incorrect right margin text = lorem.paragraphs(1) * 200 - (dx, dy), base = cv2.getTextSize(text, fontFace=font_style, fontScale=font_scale, thickness=1) + (dx, dy), base = cv2.getTextSize( + text, fontFace=font_style, fontScale=font_scale, thickness=1 + ) dy += base # char_width = dx // len(text) text = textwrap.fill(text=text, width=(dx // page_width)) @@ -90,4 +98,3 @@ def paste_image(page_image, image, coords): image = Image.fromarray(image.astype("uint8")).convert("RGBA") page_image.paste(image, coords) return page_image - diff --git a/test/fixtures/server.py b/test/fixtures/server.py index 4c959e7..982b89a 100644 --- a/test/fixtures/server.py +++ b/test/fixtures/server.py @@ -47,6 +47,7 @@ def expected_analyse_metadata(operation, random_image_metadata_package, image_si result_metadata = {} if operation == "mock": + return {**metadata, **result_metadata} diff --git a/test/fixtures/table_parsing.py b/test/fixtures/table_parsing.py new file mode 100644 index 0000000..94e8cb0 --- /dev/null +++ b/test/fixtures/table_parsing.py @@ -0,0 +1,252 @@ +import json +from os.path import join +import cv2 +import pytest +from funcy import first + +from cv_analysis.locations import TEST_DATA_DIR +from cv_analysis.utils.draw import draw_rectangles +from cv_analysis.utils.preprocessing import open_pdf +from test.fixtures.figure_detection import paste_text + + +@pytest.fixture +def client_page_with_table(test_file_index): + img_path = join(TEST_DATA_DIR, f"test{test_file_index}.png") + return first(open_pdf(img_path)) + + +@pytest.fixture +def expected_table_annotation(test_file_index): + json_path = join(TEST_DATA_DIR, f"test{test_file_index}.json") + with open(json_path) as f: + return json.load(f) + + +@pytest.fixture +def page_with_table( + background, table_shape, table_style, n_tables, line_thickness, line_type +): + page = draw_table( + background, + (100, 100), + table_shape, + table_style, + line_thickness, + line_type=line_type, + ) + if n_tables == 2: + page = draw_table( + page, (200, 2000), table_shape, table_style, line_thickness, line_type + ) + return page + + +@pytest.fixture +def page_with_patchy_table(page_with_table, background_color): + page = page_with_table + page_width = 2480 + page_height = 3508 + x_start = 0 + y_start = 0 + for x in range(0, page_width, 325): + page = cv2.line( + page, + (x, y_start), + (x, page_height), + tuple(3 * [background_color]), + 2, + cv2.LINE_AA, + ) + for y in range(0, page_height, 515): + page = cv2.line( + page, + (x_start, y), + (page_width, y), + tuple(3 * [background_color]), + 1, + cv2.LINE_AA, + ) + return page + + +@pytest.fixture +def page_with_table_and_text(page_with_table): + return paste_text(page_with_table, (50, 1500), 1, cv2.FONT_HERSHEY_COMPLEX, 1700) + + +@pytest.fixture +def expected_gold_page_with_table(page_with_table, n_tables): + result = [ + (103, 103, 185, 198), + (291, 103, 185, 198), + (479, 103, 185, 198), + (667, 103, 185, 198), + (855, 103, 185, 198), + (1043, 103, 185, 198), + (1231, 103, 185, 198), + (1419, 103, 181, 198), + (103, 304, 185, 198), + (291, 304, 185, 198), + (479, 304, 185, 198), + (667, 304, 185, 198), + (855, 304, 185, 198), + (1043, 304, 185, 198), + (1231, 304, 185, 198), + (1419, 304, 181, 198), + (103, 505, 185, 198), + (291, 505, 185, 198), + (479, 505, 185, 198), + (667, 505, 185, 198), + (855, 505, 185, 198), + (1043, 505, 185, 198), + (1231, 505, 185, 198), + (1419, 505, 181, 198), + (103, 706, 185, 198), + (291, 706, 185, 198), + (479, 706, 185, 198), + (667, 706, 185, 198), + (855, 706, 185, 198), + (1043, 706, 185, 198), + (1231, 706, 185, 198), + (1419, 706, 181, 198), + (103, 907, 185, 193), + (291, 907, 185, 193), + (479, 907, 185, 193), + (667, 907, 185, 193), + (855, 907, 185, 193), + (1043, 907, 185, 193), + (1231, 907, 185, 193), + (1419, 907, 181, 193), + ] + if n_tables == 2: + result = [ + (103, 103, 185, 198), + (291, 103, 185, 198), + (479, 103, 185, 198), + (667, 103, 185, 198), + (855, 103, 185, 198), + (1043, 103, 185, 198), + (1231, 103, 185, 198), + (1419, 103, 181, 198), + (103, 304, 185, 198), + (291, 304, 185, 198), + (479, 304, 185, 198), + (667, 304, 185, 198), + (855, 304, 185, 198), + (1043, 304, 185, 198), + (1231, 304, 185, 198), + (1419, 304, 181, 198), + (103, 505, 185, 198), + (291, 505, 185, 198), + (479, 505, 185, 198), + (667, 505, 185, 198), + (855, 505, 185, 198), + (1043, 505, 185, 198), + (1231, 505, 185, 198), + (1419, 505, 181, 198), + (103, 706, 185, 198), + (291, 706, 185, 198), + (479, 706, 185, 198), + (667, 706, 185, 198), + (855, 706, 185, 198), + (1043, 706, 185, 198), + (1231, 706, 185, 198), + (1419, 706, 181, 198), + (103, 907, 185, 193), + (291, 907, 185, 193), + (479, 907, 185, 193), + (667, 907, 185, 193), + (855, 907, 185, 193), + (1043, 907, 185, 193), + (1231, 907, 185, 193), + (1419, 907, 181, 193), + (203, 2003, 186, 199), + (390, 2003, 187, 199), + (578, 2003, 187, 199), + (766, 2003, 187, 199), + (954, 2003, 187, 199), + (1142, 2003, 187, 199), + (1330, 2003, 187, 199), + (1518, 2003, 182, 199), + (203, 2203, 186, 200), + (390, 2203, 187, 200), + (578, 2203, 187, 200), + (766, 2203, 187, 200), + (954, 2203, 187, 200), + (1142, 2203, 187, 200), + (1330, 2203, 187, 200), + (1518, 2203, 182, 200), + (203, 2404, 186, 200), + (390, 2404, 187, 200), + (578, 2404, 187, 200), + (766, 2404, 187, 200), + (954, 2404, 187, 200), + (1142, 2404, 187, 200), + (1330, 2404, 187, 200), + (1518, 2404, 182, 200), + (203, 2605, 186, 200), + (390, 2605, 187, 200), + (578, 2605, 187, 200), + (766, 2605, 187, 200), + (954, 2605, 187, 200), + (1142, 2605, 187, 200), + (1330, 2605, 187, 200), + (1518, 2605, 182, 200), + (203, 2806, 186, 194), + (390, 2806, 187, 194), + (578, 2806, 187, 194), + (766, 2806, 187, 194), + (954, 2806, 187, 194), + (1142, 2806, 187, 194), + (1330, 2806, 187, 194), + (1518, 2806, 182, 194), + ] + return result + + +def draw_table( + page, table_position, table_shape, table_style, line_thickness, line_type +): + bbox_table = (*table_position, 1500, 1000) + page = draw_grid_lines( + page, + table_shape, + bbox_table, + table_style, + thickness=line_thickness, + line_type=line_type, + ) + if "closed" in table_style: + page = draw_rectangles(page, [bbox_table], (0, 0, 0)) + return page + + +def draw_grid_lines(image, table_shape, bbox, visible_lines, thickness, line_type): + x, y, w, h = bbox + n_rows, n_columns = table_shape + cell_width = bbox[2] // n_columns + 1 + cell_height = bbox[3] // n_rows + 1 + x_line, y_line = x + cell_width, y + cell_height + if "horizontal" in visible_lines: + for y_line in range(y_line, y + h, cell_height): + image = cv2.line( + image, + (x, y_line), + (x + w, y_line), + color=(0, 0, 0), + thickness=thickness, + lineType=line_type, + ) + + if "vertical" in visible_lines: + for x_line in range(x_line, x + w, cell_width): + image = cv2.line( + image, + (x_line, y), + (x_line, y + h), + color=(0, 0, 0), + thickness=thickness, + lineType=line_type, + ) + return image diff --git a/test/unit_tests/figure_detection/figure_detection_pipeline_test.py b/test/unit_tests/figure_detection/figure_detection_pipeline_test.py index 57b5526..b51e7a4 100644 --- a/test/unit_tests/figure_detection/figure_detection_pipeline_test.py +++ b/test/unit_tests/figure_detection/figure_detection_pipeline_test.py @@ -3,8 +3,6 @@ from math import prod import cv2 import pytest -from cv_analysis.utils.display import show_mpl -from cv_analysis.utils.draw import draw_rectangles from test.utils.utils import powerset @@ -43,7 +41,12 @@ class TestFindPrimaryTextRegions: @pytest.mark.parametrize("text_types", powerset(["body", "header", "caption"])) @pytest.mark.parametrize("error_tolerance", [0.9]) def test_page_with_images_and_text_yields_only_figures( - self, figure_detection_pipeline, page_with_images_and_text, image_size, n_images, error_tolerance + self, + figure_detection_pipeline, + page_with_images_and_text, + image_size, + n_images, + error_tolerance, ): results = list(figure_detection_pipeline(page_with_images_and_text)) diff --git a/test/unit_tests/figure_detection/text_test.py b/test/unit_tests/figure_detection/text_test.py index ce04285..6983d79 100644 --- a/test/unit_tests/figure_detection/text_test.py +++ b/test/unit_tests/figure_detection/text_test.py @@ -2,7 +2,10 @@ import cv2 import numpy as np import pytest -from cv_analysis.figure_detection.text import remove_primary_text_regions, apply_threshold_to_image +from cv_analysis.figure_detection.text import ( + remove_primary_text_regions, + apply_threshold_to_image, +) from cv_analysis.utils.display import show_mpl from test.utils.utils import powerset @@ -22,19 +25,33 @@ class TestFindPrimaryTextRegions: np.testing.assert_equal(result_page, apply_threshold_to_image(page_with_images)) @pytest.mark.parametrize("font_scale", [1, 1.5, 2]) - @pytest.mark.parametrize("font_style", [cv2.FONT_HERSHEY_SIMPLEX, cv2.FONT_HERSHEY_COMPLEX]) + @pytest.mark.parametrize( + "font_style", [cv2.FONT_HERSHEY_SIMPLEX, cv2.FONT_HERSHEY_COMPLEX] + ) @pytest.mark.parametrize("text_types", powerset(["body", "header", "caption"])) - def test_page_with_only_text_gets_text_removed(self, page_with_text, error_tolerance): + def test_page_with_only_text_gets_text_removed( + self, page_with_text, error_tolerance + ): result_page = remove_primary_text_regions(page_with_text) - relative_error = np.sum(result_page != apply_threshold_to_image(page_with_text)) / result_page.size + relative_error = ( + np.sum(result_page != apply_threshold_to_image(page_with_text)) + / result_page.size + ) assert relative_error <= error_tolerance @pytest.mark.parametrize("image_size", [(200, 200), (500, 500), (800, 800)]) @pytest.mark.parametrize("n_images", [1, 2]) @pytest.mark.parametrize("font_scale", [1, 1.5, 2]) - @pytest.mark.parametrize("font_style", [cv2.FONT_HERSHEY_SIMPLEX, cv2.FONT_HERSHEY_COMPLEX]) + @pytest.mark.parametrize( + "font_style", [cv2.FONT_HERSHEY_SIMPLEX, cv2.FONT_HERSHEY_COMPLEX] + ) @pytest.mark.parametrize("text_types", powerset(["body", "header", "caption"])) - def test_page_with_images_and_text_keeps_images(self, page_with_images_and_text, error_tolerance): + def test_page_with_images_and_text_keeps_images( + self, page_with_images_and_text, error_tolerance + ): result_page = remove_primary_text_regions(page_with_images_and_text) - relative_error = np.sum(result_page != apply_threshold_to_image(page_with_images_and_text)) / result_page.size + relative_error = ( + np.sum(result_page != apply_threshold_to_image(page_with_images_and_text)) + / result_page.size + ) assert relative_error <= error_tolerance diff --git a/test/unit_tests/server/formatted_stream_fn_test.py b/test/unit_tests/server/formatted_stream_fn_test.py index 88526fe..2cf1e3a 100644 --- a/test/unit_tests/server/formatted_stream_fn_test.py +++ b/test/unit_tests/server/formatted_stream_fn_test.py @@ -6,7 +6,9 @@ from cv_analysis.server.stream import make_streamable_analysis_fn @pytest.mark.parametrize("operation", ["mock"]) @pytest.mark.parametrize("image_size", [(200, 200), (500, 500), (800, 800)]) -def test_make_analysis_fn(analysis_fn_mock, random_image_metadata_package, expected_analyse_metadata): +def test_make_analysis_fn( + analysis_fn_mock, random_image_metadata_package, expected_analyse_metadata +): analyse = make_streamable_analysis_fn(analysis_fn_mock) results = first(analyse(random_image_metadata_package)) diff --git a/test/unit_tests/table_parsing_test.py b/test/unit_tests/table_parsing_test.py index 888aeec..0c14725 100644 --- a/test/unit_tests/table_parsing_test.py +++ b/test/unit_tests/table_parsing_test.py @@ -1,20 +1,18 @@ -import json -from os.path import join +from itertools import starmap +import cv2 import pytest -from funcy import first -from cv_analysis.locations import TEST_DATA_DIR from cv_analysis.table_parsing import parse_tables -from cv_analysis.utils.preprocessing import open_pdf from cv_analysis.utils.test_metrics import compute_document_score @pytest.mark.parametrize("score_threshold", [0.95]) @pytest.mark.parametrize("test_file_index", range(1, 11)) -def test_table_parsing(score_threshold, image_with_tables, expected_table_annotation, test_file_index): - - result = [x.json_xywh() for x in parse_tables(image_with_tables)] +def test_table_parsing_on_client_pages( + score_threshold, client_page_with_table, expected_table_annotation, test_file_index +): + result = [x.json_xywh() for x in parse_tables(client_page_with_table)] formatted_result = {"pages": [{"page": str(test_file_index), "cells": result}]} score = compute_document_score(formatted_result, expected_table_annotation) @@ -23,13 +21,42 @@ def test_table_parsing(score_threshold, image_with_tables, expected_table_annota @pytest.fixture -def image_with_tables(test_file_index): - img_path = join(TEST_DATA_DIR, f"test{test_file_index}.png") - return first(open_pdf(img_path)) +def error_tolerance(line_thickness): + return line_thickness * 7 -@pytest.fixture -def expected_table_annotation(test_file_index): - json_path = join(TEST_DATA_DIR, f"test{test_file_index}.json") - with open(json_path) as f: - return json.load(f) +@pytest.mark.parametrize("line_thickness", [1, 2, 3]) +@pytest.mark.parametrize("line_type", [cv2.LINE_4, cv2.LINE_AA, cv2.LINE_8]) +@pytest.mark.parametrize("table_style", ["closed horizontal vertical", "open horizontal vertical"]) +@pytest.mark.parametrize("n_tables", [1, 2]) +@pytest.mark.parametrize("background_color", [255, 220]) +@pytest.mark.parametrize("table_shape", [(5, 8)]) +def test_table_parsing_on_generic_pages(page_with_table, expected_gold_page_with_table, error_tolerance): + result = [x.xywh() for x in parse_tables(page_with_table)] + assert ( + result == expected_gold_page_with_table + or average_error(result, expected_gold_page_with_table) <= error_tolerance + ) + + +@pytest.mark.parametrize("line_thickness", [1, 2, 3]) +@pytest.mark.parametrize("line_type", [cv2.LINE_4, cv2.LINE_AA, cv2.LINE_8]) +@pytest.mark.parametrize("table_style", ["closed horizontal vertical", "open horizontal vertical"]) +@pytest.mark.parametrize("n_tables", [1, 2]) +@pytest.mark.parametrize("background_color", [255, 220]) +@pytest.mark.parametrize("table_shape", [(5, 8)]) +@pytest.mark.xfail +def test_bad_qual_table(page_with_patchy_table, expected_gold_page_with_table, error_tolerance): + result = [x.xywh() for x in parse_tables(page_with_patchy_table)] + assert ( + result == expected_gold_page_with_table + or average_error(result, expected_gold_page_with_table) <= error_tolerance + ) + + +def average_error(result, expected): + return sum(starmap(calc_rect_diff, zip(result, expected))) / len(expected) + + +def calc_rect_diff(rect1, rect2): + return sum(abs(c1 - c2) for c1, c2 in zip(rect1, rect2))