Refactoring

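Various refactorings: rename the pipeline factory to make_analysis_pipeline_for_element_type, replace calls to Rectangle.xyxy()/xywh()/to_dict() with free helper functions, and make the CLI script print its JSON results instead of writing them to a file.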
This commit is contained in:
Matthias Bisping 2023-01-04 17:35:58 +01:00
parent 97fb4b645d
commit 9d2f166fbf
5 changed files with 63 additions and 71 deletions

View File

@ -11,28 +11,23 @@ from pdf2img.default_objects.image import ImagePlus, ImageInfo
from pdf2img.default_objects.rectangle import RectanglePlus
def get_analysis_pipeline(operation, table_parsing_skip_pages_without_images):
if operation == "table":
return make_analysis_pipeline(
parse_tables,
table_parsing_formatter,
dpi=200,
skip_pages_without_images=table_parsing_skip_pages_without_images,
)
elif operation == "figure":
return make_analysis_pipeline(detect_figures, figure_detection_formatter, dpi=200)
def make_analysis_pipeline_for_element_type(segment_type, **kwargs):
    """Build the analysis pipeline for the given kind of page element.

    Args:
        segment_type: Either "table" or "figure".
        **kwargs: Forwarded unchanged to ``make_analysis_pipeline``.

    Returns:
        Whatever ``make_analysis_pipeline`` returns for the selected analysis
        function and formatter, always created with ``dpi=200``.

    Raises:
        ValueError: If ``segment_type`` is neither "table" nor "figure".
    """
    if segment_type == "table":
        return make_analysis_pipeline(parse_tables, table_parsing_formatter, dpi=200, **kwargs)
    if segment_type == "figure":
        return make_analysis_pipeline(detect_figures, figure_detection_formatter, dpi=200, **kwargs)
    # Fail with an explicit, descriptive error. A bare `raise` here would
    # surface as "RuntimeError: No active exception to reraise" and hide the
    # offending value from the caller.
    raise ValueError(f"Unknown segment type {segment_type}.")
def make_analysis_pipeline(analysis_fn, formatter, dpi, skip_pages_without_images=False):
def analyse_pipeline(pdf: bytes, index=None):
def analysis_pipeline(pdf: bytes, index=None):
def parse_page(page: ImagePlus):
image = page.asarray()
rects = analysis_fn(image)
if not rects:
rectangles = analysis_fn(image)
if not rectangles:
return
infos = formatter(rects, page, dpi)
infos = formatter(rectangles, page, dpi)
return infos
pages = convert_pages_to_images(pdf, index=index, dpi=dpi, skip_pages_without_images=skip_pages_without_images)
@ -40,22 +35,26 @@ def make_analysis_pipeline(analysis_fn, formatter, dpi, skip_pages_without_image
yield from flatten(filter(truth, results))
return analyse_pipeline
return analysis_pipeline
def table_parsing_formatter(rects, page: ImagePlus, dpi):
def format_rect(rect: Rectangle):
rect_plus = RectanglePlus.from_pixels(*rect.xyxy(), page.info, alpha=False, dpi=dpi)
return rect_plus.asdict(derotate=True)
def table_parsing_formatter(rectangles, page: ImagePlus, dpi):
def format_rectangle(rectangle: Rectangle):
rectangle_plus = RectanglePlus.from_pixels(*rectangle_to_xyxy(rectangle), page.info, alpha=False, dpi=dpi)
return rectangle_plus.asdict(derotate=True)
bboxes = lmap(format_rect, rects)
bboxes = lmap(format_rectangle, rectangles)
return {"pageInfo": page.asdict(natural_index=True), "tableCells": bboxes}
def figure_detection_formatter(rects, page, dpi):
def format_rect(rect: Rectangle):
rect_plus = RectanglePlus.from_pixels(*rect.xyxy(), page.info, alpha=False, dpi=dpi)
def figure_detection_formatter(rectangles, page, dpi):
def format_rectangle(rectangle: Rectangle):
rect_plus = RectanglePlus.from_pixels(*rectangle_to_xyxy(rectangle), page.info, alpha=False, dpi=dpi)
return asdict(ImageInfo(page.info, rect_plus.asbbox(derotate=False), rect_plus.alpha))
return lmap(format_rect, rects)
return lmap(format_rectangle, rectangles)
def rectangle_to_xyxy(rectangle: Rectangle):
    """Return the rectangle's coordinates as an ``(x1, y1, x2, y2)`` tuple."""
    x1, y1 = rectangle.x1, rectangle.y1
    x2, y2 = rectangle.x2, rectangle.y2
    return x1, y1, x2, y2

View File

@ -48,38 +48,17 @@ class Rectangle:
def h(self):
return self.__h
def __str__(self):
return dumps(self.to_dict())
def __repr__(self):
return str(self)
def __iter__(self):
return list(self.to_dict().values()).__iter__()
def __eq__(self, other: Rectangle):
return all([self.x1 == other.x1, self.y1 == other.y1, self.w == other.w, self.h == other.h])
def __hash__(self):
return hash((self.x1, self.y1, self.x2, self.y2))
@classmethod
def from_xywh(cls, xywh: Iterable[Coord], discrete=True):
"""Creates a rectangle from a point and a width and height."""
"""Creates a rectangle from a point, width and height."""
x1, y1, w, h = xywh
x2 = x1 + w
y2 = y1 + h
return cls(x1, y1, x2, y2, discrete=discrete)
def xyxy(self):
return self.x1, self.y1, self.x2, self.y2
def xywh(self):
return self.x1, self.y1, self.w, self.h
def to_dict(self):
return {"x": self.x1, "y": self.y1, "width": self.w, "height": self.h}
def intersection(self, other):
"""Calculates the intersection of this and another rectangle."""
return intersection(self, other)

View File

@ -2,28 +2,27 @@ import argparse
import json
from pathlib import Path
from cv_analysis.server.pipeline import get_analysis_pipeline
from loguru import logger
from cv_analysis.server.pipeline import make_analysis_pipeline_for_element_type
def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument("pdf")
parser.add_argument("--type", "-t", choices=["table", "layout", "figure"], required=True)
parser.add_argument("pdf", type=Path)
parser.add_argument("--element_type", "-t", choices=["table", "figure"], required=True)
return parser.parse_args()
def main(args):
    """Analyse the PDF named on the command line and print the results.

    Reads ``args.pdf`` as bytes, runs the pipeline built for
    ``args.element_type`` over it and prints the collected results to stdout
    as JSON indented by two spaces.
    """
    pipeline = make_analysis_pipeline_for_element_type(args.element_type)
    logger.info(f"Analysing document for {args.element_type}s...")
    pdf_bytes = args.pdf.read_bytes()
    # The pipeline's result is materialised into a list before serialising.
    findings = [*pipeline(pdf_bytes)]
    print(json.dumps(findings, indent=2))
if __name__ == "__main__":
args = parse_args()
analysis_fn = get_analysis_pipeline(args.type)
with open(args.pdf, "rb") as f:
pdf_bytes = f.read()
results = list(analysis_fn(pdf_bytes))
folder = Path(args.pdf).parent
file_stem = Path(args.pdf).stem
with open(f"{folder}/{file_stem}_{args.type}.json", "w+") as f:
json.dump(results, f, indent=2)
main(parse_args())

View File

@ -4,7 +4,7 @@ import logging
from operator import itemgetter
from cv_analysis.config import get_config
from cv_analysis.server.pipeline import get_analysis_pipeline
from cv_analysis.server.pipeline import make_analysis_pipeline_for_segment_type
from cv_analysis.utils.banner import make_art
from pyinfra import config as pyinfra_config
from pyinfra.queue.queue_manager import QueueManager
@ -31,7 +31,10 @@ def analysis_callback(queue_message: dict):
should_publish_result = True
object_bytes = gzip.decompress(storage.get_object(bucket, object_name))
analysis_fn = get_analysis_pipeline(operation, CV_CONFIG.table_parsing_skip_pages_without_images)
analysis_fn = make_analysis_pipeline_for_segment_type(
operation,
skip_pages_without_images=CV_CONFIG.table_parsing_skip_pages_without_images,
)
results = analysis_fn(object_bytes)
response = {**queue_message, "data": list(results)}

View File

@ -2,8 +2,11 @@ from itertools import starmap
import cv2
import pytest
from funcy import lmap, compose
from cv_analysis.table_parsing import parse_tables
from cv_analysis.utils import lift
from cv_analysis.utils.structures import Rectangle
from cv_analysis.utils.test_metrics import compute_document_score
@ -12,8 +15,9 @@ from cv_analysis.utils.test_metrics import compute_document_score
def test_table_parsing_on_client_pages(
score_threshold, client_page_with_table, expected_table_annotation, test_file_index
):
result = [rect.to_dict() for rect in parse_tables(client_page_with_table)]
formatted_result = {"pages": [{"page": str(test_file_index), "cells": result}]}
results = compose(lift(rectangle_to_dict), parse_tables)(client_page_with_table)
formatted_result = {"pages": [{"cells": results}]}
score = compute_document_score(formatted_result, expected_table_annotation)
@ -25,6 +29,14 @@ def error_tolerance(line_thickness):
return line_thickness * 7
def rectangle_to_dict(rectangle: Rectangle):
    """Describe a rectangle as a dict with "x", "y", "width" and "height" keys."""
    x, y = rectangle.x1, rectangle.y1
    width, height = rectangle.w, rectangle.h
    return {"x": x, "y": y, "width": width, "height": height}
def rectangle_to_xywh(rectangle: Rectangle):
    """Return ``(x1, y1, w, h)`` read from the rectangle's attributes."""
    origin = (rectangle.x1, rectangle.y1)
    size = (rectangle.w, rectangle.h)
    return (*origin, *size)
@pytest.mark.parametrize("line_thickness", [1, 2, 3])
@pytest.mark.parametrize("line_type", [cv2.LINE_4, cv2.LINE_AA, cv2.LINE_8])
@pytest.mark.parametrize("table_style", ["closed horizontal vertical", "open horizontal vertical"])
@ -32,7 +44,7 @@ def error_tolerance(line_thickness):
@pytest.mark.parametrize("background_color", [255, 220])
@pytest.mark.parametrize("table_shape", [(5, 8)])
def test_table_parsing_on_generic_pages(page_with_table, expected_gold_page_with_table, error_tolerance):
result = [x.xywh() for x in parse_tables(page_with_table)]
result = lmap(rectangle_to_xywh, parse_tables(page_with_table))
assert (
result == expected_gold_page_with_table
or average_error(result, expected_gold_page_with_table) <= error_tolerance
@ -46,8 +58,8 @@ def test_table_parsing_on_generic_pages(page_with_table, expected_gold_page_with
@pytest.mark.parametrize("background_color", [255, 220])
@pytest.mark.parametrize("table_shape", [(5, 8)])
@pytest.mark.xfail
def test_bad_qual_table(page_with_patchy_table, expected_gold_page_with_table, error_tolerance):
result = [x.xywh() for x in parse_tables(page_with_patchy_table)]
def test_low_quality_table(page_with_patchy_table, expected_gold_page_with_table, error_tolerance):
result = lmap(rectangle_to_xywh, parse_tables(page_with_patchy_table))
assert (
result == expected_gold_page_with_table
or average_error(result, expected_gold_page_with_table) <= error_tolerance