From 9d2f166fbfdc186386107ee067207dacc67d4a13 Mon Sep 17 00:00:00 2001 From: Matthias Bisping Date: Wed, 4 Jan 2023 17:35:58 +0100 Subject: [PATCH] Refactoring Various --- cv_analysis/server/pipeline.py | 49 +++++++++++++-------------- cv_analysis/utils/structures.py | 23 +------------ scripts/run_analysis_pipeline.py | 33 +++++++++--------- src/serve.py | 7 ++-- test/unit_tests/table_parsing_test.py | 22 +++++++++--- 5 files changed, 63 insertions(+), 71 deletions(-) diff --git a/cv_analysis/server/pipeline.py b/cv_analysis/server/pipeline.py index 01aa05e..3de4d4f 100644 --- a/cv_analysis/server/pipeline.py +++ b/cv_analysis/server/pipeline.py @@ -11,28 +11,23 @@ from pdf2img.default_objects.image import ImagePlus, ImageInfo from pdf2img.default_objects.rectangle import RectanglePlus -def get_analysis_pipeline(operation, table_parsing_skip_pages_without_images): - if operation == "table": - return make_analysis_pipeline( - parse_tables, - table_parsing_formatter, - dpi=200, - skip_pages_without_images=table_parsing_skip_pages_without_images, - ) - elif operation == "figure": - return make_analysis_pipeline(detect_figures, figure_detection_formatter, dpi=200) +def make_analysis_pipeline_for_element_type(segment_type, **kwargs): + if segment_type == "table": + return make_analysis_pipeline(parse_tables, table_parsing_formatter, dpi=200, **kwargs) + elif segment_type == "figure": + return make_analysis_pipeline(detect_figures, figure_detection_formatter, dpi=200, **kwargs) else: - raise + raise ValueError(f"Unknown segment type {segment_type}.") def make_analysis_pipeline(analysis_fn, formatter, dpi, skip_pages_without_images=False): - def analyse_pipeline(pdf: bytes, index=None): + def analysis_pipeline(pdf: bytes, index=None): def parse_page(page: ImagePlus): image = page.asarray() - rects = analysis_fn(image) - if not rects: + rectangles = analysis_fn(image) + if not rectangles: return - infos = formatter(rects, page, dpi) + infos = formatter(rectangles, page, 
dpi) return infos pages = convert_pages_to_images(pdf, index=index, dpi=dpi, skip_pages_without_images=skip_pages_without_images) @@ -40,22 +35,26 @@ def make_analysis_pipeline(analysis_fn, formatter, dpi, skip_pages_without_image yield from flatten(filter(truth, results)) - return analyse_pipeline + return analysis_pipeline -def table_parsing_formatter(rects, page: ImagePlus, dpi): - def format_rect(rect: Rectangle): - rect_plus = RectanglePlus.from_pixels(*rect.xyxy(), page.info, alpha=False, dpi=dpi) - return rect_plus.asdict(derotate=True) +def table_parsing_formatter(rectangles, page: ImagePlus, dpi): + def format_rectangle(rectangle: Rectangle): + rectangle_plus = RectanglePlus.from_pixels(*rectangle_to_xyxy(rectangle), page.info, alpha=False, dpi=dpi) + return rectangle_plus.asdict(derotate=True) - bboxes = lmap(format_rect, rects) + bboxes = lmap(format_rectangle, rectangles) return {"pageInfo": page.asdict(natural_index=True), "tableCells": bboxes} -def figure_detection_formatter(rects, page, dpi): - def format_rect(rect: Rectangle): - rect_plus = RectanglePlus.from_pixels(*rect.xyxy(), page.info, alpha=False, dpi=dpi) +def figure_detection_formatter(rectangles, page, dpi): + def format_rectangle(rectangle: Rectangle): + rect_plus = RectanglePlus.from_pixels(*rectangle_to_xyxy(rectangle), page.info, alpha=False, dpi=dpi) return asdict(ImageInfo(page.info, rect_plus.asbbox(derotate=False), rect_plus.alpha)) - return lmap(format_rect, rects) + return lmap(format_rectangle, rectangles) + + +def rectangle_to_xyxy(rectangle: Rectangle): + return rectangle.x1, rectangle.y1, rectangle.x2, rectangle.y2 diff --git a/cv_analysis/utils/structures.py b/cv_analysis/utils/structures.py index f205eb7..e2a4ef2 100644 --- a/cv_analysis/utils/structures.py +++ b/cv_analysis/utils/structures.py @@ -48,38 +48,17 @@ class Rectangle: def h(self): return self.__h - def __str__(self): - return dumps(self.to_dict()) - - def __repr__(self): - return str(self) - - def 
__iter__(self): - return list(self.to_dict().values()).__iter__() - - def __eq__(self, other: Rectangle): - return all([self.x1 == other.x1, self.y1 == other.y1, self.w == other.w, self.h == other.h]) - def __hash__(self): return hash((self.x1, self.y1, self.x2, self.y2)) @classmethod def from_xywh(cls, xywh: Iterable[Coord], discrete=True): - """Creates a rectangle from a point and a width and height.""" + """Creates a rectangle from a point, width and height.""" x1, y1, w, h = xywh x2 = x1 + w y2 = y1 + h return cls(x1, y1, x2, y2, discrete=discrete) - def xyxy(self): - return self.x1, self.y1, self.x2, self.y2 - - def xywh(self): - return self.x1, self.y1, self.w, self.h - - def to_dict(self): - return {"x": self.x1, "y": self.y1, "width": self.w, "height": self.h} - def intersection(self, other): """Calculates the intersection of this and another rectangle.""" return intersection(self, other) diff --git a/scripts/run_analysis_pipeline.py b/scripts/run_analysis_pipeline.py index 3c8e37f..73cf806 100644 --- a/scripts/run_analysis_pipeline.py +++ b/scripts/run_analysis_pipeline.py @@ -2,28 +2,27 @@ import argparse import json from pathlib import Path -from cv_analysis.server.pipeline import get_analysis_pipeline +from loguru import logger + +from cv_analysis.server.pipeline import make_analysis_pipeline_for_element_type def parse_args(): parser = argparse.ArgumentParser() - parser.add_argument("pdf") - parser.add_argument("--type", "-t", choices=["table", "layout", "figure"], required=True) + parser.add_argument("pdf", type=Path) + parser.add_argument("--element_type", "-t", choices=["table", "figure"], required=True) return parser.parse_args() +def main(args): + + analysis_fn = make_analysis_pipeline_for_element_type(args.element_type) + + logger.info(f"Analysing document for {args.element_type}s...") + results = list(analysis_fn(args.pdf.read_bytes())) + + print(json.dumps(results, indent=2)) + + if __name__ == "__main__": - args = parse_args() - - analysis_fn = 
get_analysis_pipeline(args.type)
-
-    with open(args.pdf, "rb") as f:
-        pdf_bytes = f.read()
-
-    results = list(analysis_fn(pdf_bytes))
-
-    folder = Path(args.pdf).parent
-    file_stem = Path(args.pdf).stem
-
-    with open(f"{folder}/{file_stem}_{args.type}.json", "w+") as f:
-        json.dump(results, f, indent=2)
+    main(parse_args())
diff --git a/src/serve.py b/src/serve.py
index 81405bd..0df0c65 100644
--- a/src/serve.py
+++ b/src/serve.py
@@ -4,7 +4,7 @@ import logging
 from operator import itemgetter
 
 from cv_analysis.config import get_config
-from cv_analysis.server.pipeline import get_analysis_pipeline
+from cv_analysis.server.pipeline import make_analysis_pipeline_for_element_type
 from cv_analysis.utils.banner import make_art
 from pyinfra import config as pyinfra_config
 from pyinfra.queue.queue_manager import QueueManager
@@ -31,7 +31,10 @@ def analysis_callback(queue_message: dict):
     should_publish_result = True
 
     object_bytes = gzip.decompress(storage.get_object(bucket, object_name))
-    analysis_fn = get_analysis_pipeline(operation, CV_CONFIG.table_parsing_skip_pages_without_images)
+    analysis_fn = make_analysis_pipeline_for_element_type(
+        operation,
+        skip_pages_without_images=CV_CONFIG.table_parsing_skip_pages_without_images,
+    )
 
     results = analysis_fn(object_bytes)
     response = {**queue_message, "data": list(results)}
diff --git a/test/unit_tests/table_parsing_test.py b/test/unit_tests/table_parsing_test.py
index e3ef8c6..db610e3 100644
--- a/test/unit_tests/table_parsing_test.py
+++ b/test/unit_tests/table_parsing_test.py
@@ -2,8 +2,11 @@ from itertools import starmap
 
 import cv2
 import pytest
+from funcy import lmap, compose
 
 from cv_analysis.table_parsing import parse_tables
+from cv_analysis.utils import lift
+from cv_analysis.utils.structures import Rectangle
 from cv_analysis.utils.test_metrics import compute_document_score
 
 
@@ -12,8 +15,9 @@ from cv_analysis.utils.test_metrics import compute_document_score
 def test_table_parsing_on_client_pages(
     score_threshold, 
client_page_with_table, expected_table_annotation, test_file_index ): - result = [rect.to_dict() for rect in parse_tables(client_page_with_table)] - formatted_result = {"pages": [{"page": str(test_file_index), "cells": result}]} + + results = compose(lift(rectangle_to_dict), parse_tables)(client_page_with_table) + formatted_result = {"pages": [{"cells": results}]} score = compute_document_score(formatted_result, expected_table_annotation) @@ -25,6 +29,14 @@ def error_tolerance(line_thickness): return line_thickness * 7 +def rectangle_to_dict(rectangle: Rectangle): + return {"x": rectangle.x1, "y": rectangle.y1, "width": rectangle.w, "height": rectangle.h} + + +def rectangle_to_xywh(rectangle: Rectangle): + return rectangle.x1, rectangle.y1, rectangle.w, rectangle.h + + @pytest.mark.parametrize("line_thickness", [1, 2, 3]) @pytest.mark.parametrize("line_type", [cv2.LINE_4, cv2.LINE_AA, cv2.LINE_8]) @pytest.mark.parametrize("table_style", ["closed horizontal vertical", "open horizontal vertical"]) @@ -32,7 +44,7 @@ def error_tolerance(line_thickness): @pytest.mark.parametrize("background_color", [255, 220]) @pytest.mark.parametrize("table_shape", [(5, 8)]) def test_table_parsing_on_generic_pages(page_with_table, expected_gold_page_with_table, error_tolerance): - result = [x.xywh() for x in parse_tables(page_with_table)] + result = lmap(rectangle_to_xywh, parse_tables(page_with_table)) assert ( result == expected_gold_page_with_table or average_error(result, expected_gold_page_with_table) <= error_tolerance @@ -46,8 +58,8 @@ def test_table_parsing_on_generic_pages(page_with_table, expected_gold_page_with @pytest.mark.parametrize("background_color", [255, 220]) @pytest.mark.parametrize("table_shape", [(5, 8)]) @pytest.mark.xfail -def test_bad_qual_table(page_with_patchy_table, expected_gold_page_with_table, error_tolerance): - result = [x.xywh() for x in parse_tables(page_with_patchy_table)] +def test_low_quality_table(page_with_patchy_table, 
expected_gold_page_with_table, error_tolerance): + result = lmap(rectangle_to_xywh, parse_tables(page_with_patchy_table)) assert ( result == expected_gold_page_with_table or average_error(result, expected_gold_page_with_table) <= error_tolerance