From dbc6d345f074e538948e2c4f94ebed8a5ef520bc Mon Sep 17 00:00:00 2001 From: Isaac Riley Date: Wed, 20 Jul 2022 16:32:42 +0200 Subject: [PATCH] removed PIL from production code, now only in scripts --- .../figure_detection_pipeline.py | 4 +- cv_analysis/layout_parsing.py | 26 +----- cv_analysis/redaction_detection.py | 23 +---- cv_analysis/server/stream.py | 4 +- cv_analysis/table_parsing.py | 14 +-- cv_analysis/utils/deskew.py | 87 ------------------- cv_analysis/utils/display.py | 36 +++++--- cv_analysis/utils/logging.py | 4 +- cv_analysis/utils/open_pdf.py | 27 ++++++ cv_analysis/utils/post_processing.py | 18 +--- cv_analysis/utils/preprocessing.py | 46 ++++------ cv_analysis/utils/test_metrics.py | 6 +- cv_analysis/utils/visual_logging.py | 8 +- scripts/annotate.py | 7 +- scripts/deskew_demo.py | 50 ----------- scripts/pyinfra_mock.py | 11 +-- test/fixtures/figure_detection.py | 22 ++--- test/fixtures/server.py | 4 +- test/fixtures/table_parsing.py | 14 +-- test/unit_tests/figure_detection/text_test.py | 28 ++---- .../server/formatted_stream_fn_test.py | 4 +- 21 files changed, 113 insertions(+), 330 deletions(-) delete mode 100644 cv_analysis/utils/deskew.py create mode 100644 cv_analysis/utils/open_pdf.py delete mode 100644 scripts/deskew_demo.py diff --git a/cv_analysis/figure_detection/figure_detection_pipeline.py b/cv_analysis/figure_detection/figure_detection_pipeline.py index f0a3b35..1a374f1 100644 --- a/cv_analysis/figure_detection/figure_detection_pipeline.py +++ b/cv_analysis/figure_detection/figure_detection_pipeline.py @@ -17,9 +17,7 @@ from cv_analysis.utils.structures import Rectangle def make_figure_detection_pipeline(min_area=5000, max_width_to_height_ratio=6): def pipeline(image: np.array): max_area = image.shape[0] * image.shape[1] * 0.99 - filter_cnts = make_filter_likely_figures( - min_area, max_area, max_width_to_height_ratio - ) + filter_cnts = make_filter_likely_figures(min_area, max_area, max_width_to_height_ratio) image = remove_primary_text_regions(image) cnts = detect_large_coherent_structures(image) diff --git a/cv_analysis/layout_parsing.py b/cv_analysis/layout_parsing.py index 3ffeecf..1e6171e 100644 --- a/cv_analysis/layout_parsing.py +++ b/cv_analysis/layout_parsing.py @@ -5,10 +5,6 @@ from operator import __and__ import cv2 import numpy as np -# from pdf2image import pdf2image - -# from cv_analysis.utils.display import show_mpl -# from cv_analysis.utils.draw import draw_rectangles from cv_analysis.utils.structures import Rectangle from cv_analysis.utils.post_processing import ( remove_overlapping, @@ -23,9 +19,7 @@ def is_likely_segment(rect, min_area=100): def find_segments(image): - contours, hierarchies = cv2.findContours( - image, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE - ) + contours, hierarchies = cv2.findContours(image, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE) mask1 = map(is_likely_segment, contours) mask2 = map(has_no_parent, hierarchies[0]) @@ -81,21 +75,3 @@ def parse_layout(image: np.array): rects = remove_overlapping(rects) return list(map(Rectangle.from_xywh, rects)) - - -# def annotate_layout_in_pdf(page, return_rects=False, show=False): - -# #page = pdf2image.convert_from_path(pdf_path, first_page=page_index + 1, last_page=page_index + 1)[0] -# #page = np.array(page) - -# rects = parse_layout(page) - -# if return_rects: -# return rects, page -# elif show: -# page = draw_rectangles(page, rects) -# vizlogger.debug(page, "layout10_output.png") -# show_mpl(page) -# else: -# page = draw_rectangles(page, rects) -# return page diff --git
a/cv_analysis/redaction_detection.py b/cv_analysis/redaction_detection.py index 3c5bf5f..b9d40d8 100644 --- a/cv_analysis/redaction_detection.py +++ b/cv_analysis/redaction_detection.py @@ -5,16 +5,12 @@ import numpy as np import pdf2image from iteration_utilities import starfilter, first -from cv_analysis.utils.display import show_mpl -from cv_analysis.utils.draw import draw_contours from cv_analysis.utils.filters import is_large_enough, is_filled, is_boxy from cv_analysis.utils.visual_logging import vizlogger def is_likely_redaction(contour, hierarchy, min_area): - return ( - is_filled(hierarchy) and is_boxy(contour) and is_large_enough(contour, min_area) - ) + return is_filled(hierarchy) and is_boxy(contour) and is_large_enough(contour, min_area) def find_redactions(image: np.array, min_normalized_area=200000): @@ -31,9 +27,7 @@ def find_redactions(image: np.array, min_normalized_area=200000): thresh = cv2.threshold(blurred, 252, 255, cv2.THRESH_BINARY)[1] vizlogger.debug(blurred, "redactions04_threshold.png") - contours, hierarchies = cv2.findContours( - thresh, cv2.RETR_TREE, cv2.CHAIN_APPROX_NONE - ) + contours, hierarchies = cv2.findContours(thresh, cv2.RETR_TREE, cv2.CHAIN_APPROX_NONE) try: contours = map( @@ -46,16 +40,3 @@ def find_redactions(image: np.array, min_normalized_area=200000): return list(contours) except: return [] - - -# def annotate_redactions_in_pdf(page, show=False): - -# #page = pdf2image.convert_from_path(pdf_path, first_page=page_index + 1, last_page=page_index + 1)[0] -# #page = np.array(page) - -# redaction_contours = find_redactions(page) -# page = draw_contours(page, redaction_contours) -# vizlogger.debug(page, "redactions05_output.png") - -# if show: -# show_mpl(page) diff --git a/cv_analysis/server/stream.py b/cv_analysis/server/stream.py index ae66475..a52eac2 100644 --- a/cv_analysis/server/stream.py +++ b/cv_analysis/server/stream.py @@ -7,7 +7,7 @@ from pyinfra.server.utils import make_streamable_and_wrap_in_packing_logic from cv_analysis.server.format import make_formatter from cv_analysis.utils.logging import get_logger -from cv_analysis.utils.preprocessing import open_img_from_bytes +from cv_analysis.utils.preprocessing import page2image logger = get_logger() @@ -26,7 +26,7 @@ def make_streamable_analysis_fn(analysis_fn: Callable): def analyse(data: bytes, metadata: dict): - image = open_img_from_bytes(gzip.decompress(data)) + image = page2image(gzip.decompress(data)) dpi = metadata["image_info"]["dpi"] width, height, rotation = itemgetter("width", "height", "rotation")(metadata["page_info"]) diff --git a/cv_analysis/table_parsing.py b/cv_analysis/table_parsing.py index 52d5292..9375a0f 100644 --- a/cv_analysis/table_parsing.py +++ b/cv_analysis/table_parsing.py @@ -15,9 +15,7 @@ from cv_analysis.layout_parsing import parse_layout def add_external_contours(image, image_h_w_lines_only): - contours, _ = cv2.findContours( - image_h_w_lines_only, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE - ) + contours, _ = cv2.findContours(image_h_w_lines_only, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE) for cnt in contours: x, y, w, h = cv2.boundingRect(cnt) cv2.rectangle(image, (x, y), (x + w, y + h), 255, 1) @@ -82,9 +80,7 @@ def isolate_vertical_and_horizontal_components(img_bin): img_bin_extended = img_bin_h | img_bin_v th1, img_bin_extended = cv2.threshold(img_bin_extended, 120, 255, cv2.THRESH_BINARY) - img_bin_final = cv2.dilate( - img_bin_extended, np.ones((1, 1), np.uint8), iterations=1 - ) + img_bin_final = cv2.dilate(img_bin_extended, np.ones((1, 1), 
np.uint8), iterations=1) # add contours before lines are extended by blurring img_bin_final = add_external_contours(img_bin_final, img_lines_raw) @@ -137,9 +133,7 @@ def turn_connected_components_into_rects(image): x1, y1, w, h, area = stat return area > 2000 and w > 35 and h > 25 - _, _, stats, _ = cv2.connectedComponentsWithStats( - ~image, connectivity=8, ltype=cv2.CV_32S - ) + _, _, stats, _ = cv2.connectedComponentsWithStats(~image, connectivity=8, ltype=cv2.CV_32S) stats = np.vstack(list(filter(is_large_enough, stats))) return stats[:, :-1][2:] @@ -149,7 +143,7 @@ def parse_tables(image: np.array, show=False): """Runs the full table parsing process. Args: - image (np.array): single PDF page, opened as PIL.Image object and converted to a numpy array + image (np.array): single PDF page, converted to a numpy array Returns: list: list of rectangles corresponding to table cells diff --git a/cv_analysis/utils/deskew.py b/cv_analysis/utils/deskew.py deleted file mode 100644 index 98f3de3..0000000 --- a/cv_analysis/utils/deskew.py +++ /dev/null @@ -1,87 +0,0 @@ -import numpy as np -from scipy.ndimage import rotate as rotate_ -import cv2 - -from cv_analysis.config import CONFIG - - -def rotate_straight(im: np.array, skew_angle: int) -> np.array: - h, w = im.shape[:2] - center = (w // 2, h // 2) - M = cv2.getRotationMatrix2D(center, skew_angle, 1.0) - rotated = cv2.warpAffine( - im, M, (w, h), flags=cv2.INTER_CUBIC, borderMode=cv2.BORDER_REPLICATE - ) - return rotated - - -def find_score(arr, angle): - data = rotate_(arr, angle, reshape=False, order=0, mode=CONFIG.deskew.mode) - hist = np.sum(data, axis=1) - score = np.sum((hist[1:] - hist[:-1]) ** 2) - return score - - -def find_best_angle(page): - lim = CONFIG.deskew.max_abs_angle - delta = CONFIG.deskew.delta - angles = np.arange(-lim, lim + delta, delta) - scores = [find_score(page, angle) for angle in angles] - best_angle = angles[scores.index(max(scores))] - return best_angle - - -def preprocess(arr: np.array): - if len(arr.shape) > 2: - arr = cv2.cvtColor(arr, cv2.COLOR_BGR2GRAY) - arr = cv2.fastNlMeansDenoising(arr, h=CONFIG.deskew.filter_strength_h) - return arr - - -def rotate(page, angle): - rotated = rotate_(page, angle, reshape=False, order=0, mode="nearest") - return rotated - - -def deskew_histbased(page: np.array): - page = preprocess(page) - best_angle = round(find_best_angle(page), 3) - - if CONFIG.deskew.verbose: - print("Skew angle from pixel histogram: {}".format(best_angle)) - - rotated = rotate(page, best_angle) - return (rotated, best_angle) - - -def needs_deskew(page: np.array) -> bool: - """ - Makes use of 'row-wise mean difference' - the difference between neighboring - on left and right halves - """ - - def split_rowmean_diff(page): - width = page.shape[1] - cutpoint = int(width / 2) - left = page[:, :cutpoint] - right = page[:, cutpoint:] - leftmeans = np.mean(left, axis=1) - rightmeans = np.mean(right, axis=1) - return rightmeans - leftmeans - - unrotated_score = np.mean(np.abs(split_rowmean_diff(page))) - angles = [-CONFIG.deskew.test_delta, CONFIG.deskew.test_delta] - scores = [ - np.mean(np.abs(split_rowmean_diff(rotate(page, angle)))) for angle in angles - ] - print(unrotated_score, scores) - return unrotated_score > min(scores) - - -if CONFIG.deskew.function == "hist": - deskew = lambda page: deskew_histbased(page) if needs_deskew(page) else (page, 0) -elif CONFIG.deskew.function == "identity": - deskew = lambda page: (page, None) -else: - raise ValueError( - "'{CONFIG.deskew.function}' is not a valid 
parameter value for CONFIG.deskew.function" - ) diff --git a/cv_analysis/utils/display.py b/cv_analysis/utils/display.py index 999c9a2..f5d9285 100644 --- a/cv_analysis/utils/display.py +++ b/cv_analysis/utils/display.py @@ -1,26 +1,34 @@ +from numpy import resize import cv2 from matplotlib import pyplot as plt -def show_mpl(image): +def show_image_cv2(image, maxdim=700): + h, w = image.shape[:2] + maxhw = max(h, w) + if maxhw > maxdim: + ratio = maxdim / maxhw + h = int(h * ratio) + w = int(w * ratio) + img = cv2.resize(image, (w, h)) + cv2.imshow("", img) + cv2.waitKey(0) + cv2.destroyAllWindows() + + +def show_image_mpl(image): fig, ax = plt.subplots(1, 1) fig.set_size_inches(20, 20) ax.imshow(image, cmap="gray") plt.show() -def save_mpl(image, path): - # fig, ax = plt.subplots(1, 1) - # figure = plt.gcf() - # figure.set_size_inches(16,12) - fig, ax = plt.subplots(1, 1) - fig.set_size_inches(20, 20) - ax.imshow(image, cmap="gray") - # plt.close() - plt.savefig(path) - plt.close() +def show_image(image, backend="m"): + if backend.startswith("m"): + show_image_mpl(image) + else: + show_image_cv2(image) -def show_cv2(image): - cv2.imshow("", image) - cv2.waitKey(0) +def save_image(image, path): + cv2.imwrite(path, image) diff --git a/cv_analysis/utils/logging.py b/cv_analysis/utils/logging.py index 6fc280f..51be0fb 100644 --- a/cv_analysis/utils/logging.py +++ b/cv_analysis/utils/logging.py @@ -8,9 +8,7 @@ from cv_analysis.config import CONFIG def make_logger_getter(): logger = logging.getLogger(__name__) logger.setLevel(logging.getLevelName(CONFIG.service.logging_level)) - formatter = logging.Formatter( - fmt="%(asctime)s %(levelname)s: %(message)s", datefmt="%d.%m.%Y - %H:%M:%S" - ) + formatter = logging.Formatter(fmt="%(asctime)s %(levelname)s: %(message)s", datefmt="%d.%m.%Y - %H:%M:%S") ch = logging.StreamHandler(sys.stdout) ch.setLevel(logging.getLevelName(CONFIG.service.logging_level)) diff --git a/cv_analysis/utils/open_pdf.py b/cv_analysis/utils/open_pdf.py new file mode 100644 index 0000000..d704ba4 --- /dev/null +++ b/cv_analysis/utils/open_pdf.py @@ -0,0 +1,27 @@ +from numpy import array, ndarray +import pdf2image +from PIL import Image + +from cv_analysis.utils.preprocessing import preprocess_page_array + + +def open_pdf(pdf, first_page=0, last_page=None): + + first_page += 1 + last_page = None if last_page is None else last_page + 1 + + if type(pdf) == str: + if pdf.lower().endswith((".png", ".jpg", ".jpeg")): + pages = [Image.open(pdf)] + elif pdf.lower().endswith(".pdf"): + pages = pdf2image.convert_from_path(pdf, first_page=first_page, last_page=last_page) + else: + raise IOError("Invalid file extension. 
Accepted filetypes:\n\t.png\n\t.jpg\n\t.jpeg\n\t.pdf") + elif type(pdf) == bytes: + pages = pdf2image.convert_from_bytes(pdf, first_page=first_page, last_page=last_page) + elif type(pdf) in {list, ndarray}: + return pdf + + pages = [preprocess_page_array(array(p)) for p in pages] + + return pages diff --git a/cv_analysis/utils/post_processing.py b/cv_analysis/utils/post_processing.py index 46da1dc..1749f2d 100644 --- a/cv_analysis/utils/post_processing.py +++ b/cv_analysis/utils/post_processing.py @@ -18,21 +18,11 @@ def remove_overlapping(rectangles): def remove_included(rectangles): def included(a, b): - return ( - b.xmin >= a.xmin - and b.ymin >= a.ymin - and b.xmax <= a.xmax - and b.ymax <= a.ymax - ) + return b.xmin >= a.xmin and b.ymin >= a.ymin and b.xmax <= a.xmax and b.ymax <= a.ymax def includes(a, b, tol=3): """does a include b?""" - return ( - b.xmin + tol >= a.xmin - and b.ymin + tol >= a.ymin - and b.xmax - tol <= a.xmax - and b.ymax - tol <= a.ymax - ) + return b.xmin + tol >= a.xmin and b.ymin + tol >= a.ymin and b.xmax - tol <= a.xmax and b.ymax - tol <= a.ymax def is_not_included(rect, rectangles): return not any(includes(r2, rect) for r2 in rectangles if not rect == r2) @@ -110,9 +100,7 @@ def __remove_isolated_sorted(rectangles): def remove_isolated(rectangles, input_sorted=False): - return (__remove_isolated_sorted if input_sorted else __remove_isolated_unsorted)( - rectangles - ) + return (__remove_isolated_sorted if input_sorted else __remove_isolated_unsorted)(rectangles) Rectangle = namedtuple("Rectangle", "xmin ymin xmax ymax") diff --git a/cv_analysis/utils/preprocessing.py b/cv_analysis/utils/preprocessing.py index d51139a..c3269d4 100644 --- a/cv_analysis/utils/preprocessing.py +++ b/cv_analysis/utils/preprocessing.py @@ -1,41 +1,29 @@ -from io import BytesIO -from numpy import array, ndarray -import pdf2image -from PIL import Image +from numpy import frombuffer, ndarray, uint8 import cv2 -def preprocess_pdf_image(page): +def preprocess_page_array(page): if len(page.shape) > 2: page = cv2.cvtColor(page, cv2.COLOR_BGR2GRAY) page = cv2.fastNlMeansDenoising(page, h=3) return page -def open_pdf(pdf, first_page=0, last_page=None): +def page2image(page): - first_page += 1 - last_page = None if last_page is None else last_page + 1 - - if type(pdf) == str: - if pdf.lower().endswith((".png", ".jpg", ".jpeg")): - pages = [Image.open(pdf)] - else: # assume pdf as default file type for a path argument - pages = pdf2image.convert_from_path( - pdf, first_page=first_page, last_page=last_page + if type(page) == bytes: + page = cv2.imdecode(frombuffer(page, uint8), cv2.IMREAD_GRAYSCALE) + elif type(page) == ndarray: + page = page + elif type(page) == str: + if page.lower().endswith((".png", ".jpg", ".jpeg")): + page = cv2.imread(page) + else: + raise IOError( + "PDFs are not a valid input type for cv-analysis." + " Use PNGs for tests and NumPy arrays for deployment." ) - elif type(pdf) == bytes: - pages = pdf2image.convert_from_bytes( - pdf, first_page=first_page, last_page=last_page - ) - elif type(pdf) in {list, ndarray}: - return pdf + else: + raise TypeError("Incompatible datatype. 
Expected bytes, numpy.ndarray, or path to an image file.") - pages = [preprocess_pdf_image(array(p)) for p in pages] - - return pages - - -def open_img_from_bytes(bytes_obj: bytes): - page = Image.open(BytesIO(bytes_obj)) - return preprocess_pdf_image(array(page)) + return preprocess_page_array(page) diff --git a/cv_analysis/utils/test_metrics.py b/cv_analysis/utils/test_metrics.py index fd0eca6..8df3d00 100644 --- a/cv_analysis/utils/test_metrics.py +++ b/cv_analysis/utils/test_metrics.py @@ -75,11 +75,7 @@ def compute_document_score(results_dict, annotation_dict): scores = [] for i in range(len(annotation_dict["pages"])): - scores.append( - compute_page_iou( - results_dict["pages"][i]["cells"], annotation_dict["pages"][i]["cells"] - ) - ) + scores.append(compute_page_iou(results_dict["pages"][i]["cells"], annotation_dict["pages"][i]["cells"])) scores = np.array(scores) doc_score = np.average(scores, weights=page_weights) diff --git a/cv_analysis/utils/visual_logging.py b/cv_analysis/utils/visual_logging.py index 983b546..e088dbe 100644 --- a/cv_analysis/utils/visual_logging.py +++ b/cv_analysis/utils/visual_logging.py @@ -1,6 +1,6 @@ import os from cv_analysis.config import CONFIG -from cv_analysis.utils.display import save_mpl +from cv_analysis.utils.display import save_image class VisualLogger: @@ -12,7 +12,7 @@ class VisualLogger: def _save(self, img, name): output_path = os.path.join(self.output_folder, name) - save_mpl(img, output_path) + save_image(img, output_path) def info(self, img, name): if self._level_is_info(): @@ -36,6 +36,4 @@ class VisualLogger: return self.level == "ALL" -vizlogger = VisualLogger( - CONFIG.visual_logging.level, CONFIG.visual_logging.output_folder -) +vizlogger = VisualLogger(CONFIG.visual_logging.level, CONFIG.visual_logging.output_folder) diff --git a/scripts/annotate.py b/scripts/annotate.py index cac9b45..e899ec8 100644 --- a/scripts/annotate.py +++ b/scripts/annotate.py @@ -8,9 +8,9 @@ python scripts/annotate.py /home/iriley/Documents/pdf/scanned/10.pdf 5 --type fi import argparse -from cv_analysis.utils.display import show_mpl +from cv_analysis.utils.display import show_image from cv_analysis.utils.draw import draw_contours, draw_rectangles -from cv_analysis.utils.preprocessing import open_pdf +from cv_analysis.utils.open_pdf import open_pdf from cv_analysis.utils.visual_logging import vizlogger @@ -28,7 +28,7 @@ def annotate_page(page_image, analysis_function, drawing_function, name="tmp.png result = analysis_function(page_image) page_image = drawing_function(page_image, result) vizlogger.debug(page_image, "redactions05_output.png") - show_mpl(page_image) + show_image(page_image) if __name__ == "__main__": @@ -46,5 +46,6 @@ if __name__ == "__main__": from cv_analysis.layout_parsing import parse_layout as analyze elif args.type == "figure": from cv_analysis.figure_detection.figure_detection_pipeline import make_figure_detection_pipeline + analyze = make_figure_detection_pipeline() annotate_page(page, analyze, draw, name=name, show=args.show) diff --git a/scripts/deskew_demo.py b/scripts/deskew_demo.py deleted file mode 100644 index b09a342..0000000 --- a/scripts/deskew_demo.py +++ /dev/null @@ -1,50 +0,0 @@ -# sample usage: python3 scripts/deskew_demo.py /path/to/crooked.pdf 0 -import argparse -import numpy as np -import pdf2image -from PIL import Image - -from cv_analysis.utils.deskew import deskew_histbased # , deskew_linebased -from cv_analysis.utils.display import show_mpl -from cv_analysis.utils.draw import draw_stats -from 
cv_analysis.table_parsing import parse_tables - - -def parse_args(): - parser = argparse.ArgumentParser() - parser.add_argument("pdf_path") - parser.add_argument("page_index", type=int) - parser.add_argument("--save_path") - - args = parser.parse_args() - - return args - - -if __name__ == "__main__": - args = parse_args() - page = pdf2image.convert_from_path(args.pdf_path, first_page=args.page_index + 1, last_page=args.page_index + 1)[0] - page = np.array(page) - - show_mpl(page) - # page_ = deskew_linebased(page, verbose=True) - # show_mpl(page_) - page_corr, _ = deskew_histbased(page, verbose=True) - show_mpl(page_corr) - if args.save_path: - page_ = Image.fromarray(page).convert("RGB") - page_.save(args.save_path.replace(".pdf", "_uncorrected.pdf")) - page_corr_ = Image.fromarray(page_corr).convert("RGB") - page_corr_.save(args.save_path.replace(".pdf", "_corrected.pdf")) - # annotate_tables_in_pdf(args.pdf_path, page_index=args.page_index) - stats = parse_tables(page) - page = draw_stats(page, stats) - show_mpl(page) - stats_corr = parse_tables(page_corr) - page_corr = draw_stats(page_corr, stats_corr) - show_mpl(page_corr) - if args.save_path: - page = Image.fromarray(page).convert("RGB") - page.save(args.save_path.replace(".pdf", "_uncorrected_annotated.pdf")) - page_corr = Image.fromarray(page_corr).convert("RGB") - page_corr.save(args.save_path.replace(".pdf", "_corrected_annotated.pdf")) diff --git a/scripts/pyinfra_mock.py b/scripts/pyinfra_mock.py index 1717521..6d45b4d 100644 --- a/scripts/pyinfra_mock.py +++ b/scripts/pyinfra_mock.py @@ -1,16 +1,11 @@ import argparse -import base64 import gzip -import io -import json from operator import itemgetter from typing import List import fitz import pdf2image -from PIL import Image from funcy import lmap, compose, pluck -from funcy import lpluck from pyinfra.default_objects import get_component_factory @@ -45,13 +40,13 @@ def draw_cells_on_page(cells: List[dict], page): def annotate_results_on_pdf(results, pdf_path, result_path): - open_pdf = fitz.open(pdf_path) + opened_pdf = fitz.open(pdf_path) metadata_per_page = pluck("metadata", results) - for page, metadata in zip(open_pdf, metadata_per_page): + for page, metadata in zip(opened_pdf, metadata_per_page): if metadata: draw_cells_on_page(metadata["cells"], page) - open_pdf.save(result_path) + opened_pdf.save(result_path) def main(args): diff --git a/test/fixtures/figure_detection.py b/test/fixtures/figure_detection.py index 386e4a7..a1ad9cd 100644 --- a/test/fixtures/figure_detection.py +++ b/test/fixtures/figure_detection.py @@ -3,18 +3,17 @@ import textwrap import cv2 import numpy as np import pytest -from PIL import Image from lorem_text import lorem from funcy import first from cv_analysis.figure_detection.figure_detection_pipeline import ( make_figure_detection_pipeline, ) -from cv_analysis.utils.display import show_mpl +from cv_analysis.utils.display import show_image @pytest.fixture def page_with_images(random_image, n_images, background): - page_image = Image.fromarray(background.astype("uint8")).convert("RGB") + page_image = background.astype("uint8") if background.ndim == 3 else cv2.cvtColor(background.astype("uint8"), cv2.COLOR_GRAY2RGB) page_image = paste_image(page_image, random_image, (200, 200)) if n_images == 2: page_image = paste_image(page_image, random_image, (1000, 2600)) @@ -32,14 +31,10 @@ def page_with_text(background, font_scale, font_style, text_types): cursor = (image.shape[1] // 2, 70) image = paste_text(image, cursor, font_scale, font_style, y_stop=body_height) cursor = (50, body_height + 70) - image = 
paste_text( - image, cursor, font_scale, font_style, y_stop=body_height * 2 - ) + image = paste_text(image, cursor, font_scale, font_style, y_stop=body_height * 2) if "caption" in text_types: cursor = (image.shape[1] // 2, image.shape[0] - 100) - image = paste_text( - image, cursor, font_scale, font_style, y_stop=body_height * 3 - ) + image = paste_text(image, cursor, font_scale, font_style, y_stop=body_height * 3) return image @@ -67,9 +62,7 @@ def paste_text(image: np.ndarray, cursor, font_scale, font_style, y_stop): def paste_text_at_cursor(x_start, y_start, y_stop): # TODO: adjust incorrect right margin text = lorem.paragraphs(1) * 200 - (dx, dy), base = cv2.getTextSize( - text, fontFace=font_style, fontScale=font_scale, thickness=1 - ) + (dx, dy), base = cv2.getTextSize(text, fontFace=font_style, fontScale=font_scale, thickness=1) dy += base # char_width = dx // len(text) text = textwrap.fill(text=text, width=(dx // page_width)) @@ -95,6 +88,7 @@ def paste_text(image: np.ndarray, cursor, font_scale, font_style, y_stop): def paste_image(page_image, image, coords): - image = Image.fromarray(image.astype("uint8")).convert("RGBA") - page_image.paste(image, coords) + h, w = image.shape[:2] + x, y = coords + page_image[y : y + h, x : x + w] = image return page_image diff --git a/test/fixtures/server.py b/test/fixtures/server.py index 982b89a..0ecec7e 100644 --- a/test/fixtures/server.py +++ b/test/fixtures/server.py @@ -2,8 +2,8 @@ import gzip import io import numpy as np +import cv2 import pytest -from PIL import Image from funcy import first from cv_analysis.utils.structures import Rectangle from incl.pyinfra.pyinfra.server.packing import bytes_to_string @@ -12,7 +12,7 @@ def random_image_as_bytes_and_compressed(random_image): - image = Image.fromarray(random_image.astype("uint8")).convert("RGBA") + image = cv2.cvtColor(random_image.astype("uint8"), cv2.COLOR_RGB2RGBA) img_byte_arr = io.BytesIO() image.save(img_byte_arr, format="PNG") return gzip.compress(img_byte_arr.getvalue()) diff --git a/test/fixtures/table_parsing.py b/test/fixtures/table_parsing.py index 94e8cb0..ccd5207 100644 --- a/test/fixtures/table_parsing.py +++ b/test/fixtures/table_parsing.py @@ -6,7 +6,7 @@ from funcy import first from cv_analysis.locations import TEST_DATA_DIR from cv_analysis.utils.draw import draw_rectangles -from cv_analysis.utils.preprocessing import open_pdf +from cv_analysis.utils.open_pdf import open_pdf from test.fixtures.figure_detection import paste_text @@ -24,9 +24,7 @@ def expected_table_annotation(test_file_index): @pytest.fixture -def page_with_table( - background, table_shape, table_style, n_tables, line_thickness, line_type -): +def page_with_table(background, table_shape, table_style, n_tables, line_thickness, line_type): page = draw_table( background, (100, 100), @@ -36,9 +34,7 @@ def page_with_table( line_type=line_type, ) if n_tables == 2: - page = draw_table( - page, (200, 2000), table_shape, table_style, line_thickness, line_type - ) + page = draw_table(page, (200, 2000), table_shape, table_style, line_thickness, line_type) return page @@ -205,9 +201,7 @@ def expected_gold_page_with_table(page_with_table, n_tables): return result -def draw_table( - page, table_position, table_shape, table_style, line_thickness, line_type -): +def draw_table(page, table_position, table_shape, table_style, line_thickness, line_type): bbox_table = (*table_position, 1500, 1000) page = draw_grid_lines( page, diff --git a/test/unit_tests/figure_detection/text_test.py 
b/test/unit_tests/figure_detection/text_test.py index 6983d79..794763b 100644 --- a/test/unit_tests/figure_detection/text_test.py +++ b/test/unit_tests/figure_detection/text_test.py @@ -6,7 +6,7 @@ from cv_analysis.figure_detection.text import ( remove_primary_text_regions, apply_threshold_to_image, ) -from cv_analysis.utils.display import show_mpl +from cv_analysis.utils.display import show_image from test.utils.utils import powerset @@ -25,33 +25,19 @@ class TestFindPrimaryTextRegions: np.testing.assert_equal(result_page, apply_threshold_to_image(page_with_images)) @pytest.mark.parametrize("font_scale", [1, 1.5, 2]) - @pytest.mark.parametrize( - "font_style", [cv2.FONT_HERSHEY_SIMPLEX, cv2.FONT_HERSHEY_COMPLEX] - ) + @pytest.mark.parametrize("font_style", [cv2.FONT_HERSHEY_SIMPLEX, cv2.FONT_HERSHEY_COMPLEX]) @pytest.mark.parametrize("text_types", powerset(["body", "header", "caption"])) - def test_page_with_only_text_gets_text_removed( - self, page_with_text, error_tolerance - ): + def test_page_with_only_text_gets_text_removed(self, page_with_text, error_tolerance): result_page = remove_primary_text_regions(page_with_text) - relative_error = ( - np.sum(result_page != apply_threshold_to_image(page_with_text)) - / result_page.size - ) + relative_error = np.sum(result_page != apply_threshold_to_image(page_with_text)) / result_page.size assert relative_error <= error_tolerance @pytest.mark.parametrize("image_size", [(200, 200), (500, 500), (800, 800)]) @pytest.mark.parametrize("n_images", [1, 2]) @pytest.mark.parametrize("font_scale", [1, 1.5, 2]) - @pytest.mark.parametrize( - "font_style", [cv2.FONT_HERSHEY_SIMPLEX, cv2.FONT_HERSHEY_COMPLEX] - ) + @pytest.mark.parametrize("font_style", [cv2.FONT_HERSHEY_SIMPLEX, cv2.FONT_HERSHEY_COMPLEX]) @pytest.mark.parametrize("text_types", powerset(["body", "header", "caption"])) - def test_page_with_images_and_text_keeps_images( - self, page_with_images_and_text, error_tolerance - ): + def test_page_with_images_and_text_keeps_images(self, page_with_images_and_text, error_tolerance): result_page = remove_primary_text_regions(page_with_images_and_text) - relative_error = ( - np.sum(result_page != apply_threshold_to_image(page_with_images_and_text)) - / result_page.size - ) + relative_error = np.sum(result_page != apply_threshold_to_image(page_with_images_and_text)) / result_page.size assert relative_error <= error_tolerance diff --git a/test/unit_tests/server/formatted_stream_fn_test.py b/test/unit_tests/server/formatted_stream_fn_test.py index 2cf1e3a..88526fe 100644 --- a/test/unit_tests/server/formatted_stream_fn_test.py +++ b/test/unit_tests/server/formatted_stream_fn_test.py @@ -6,9 +6,7 @@ from cv_analysis.server.stream import make_streamable_analysis_fn @pytest.mark.parametrize("operation", ["mock"]) @pytest.mark.parametrize("image_size", [(200, 200), (500, 500), (800, 800)]) -def test_make_analysis_fn( - analysis_fn_mock, random_image_metadata_package, expected_analyse_metadata -): +def test_make_analysis_fn(analysis_fn_mock, random_image_metadata_package, expected_analyse_metadata): analyse = make_streamable_analysis_fn(analysis_fn_mock) results = first(analyse(random_image_metadata_package))