Refactoring
Various
This commit is contained in:
parent
97fb4b645d
commit
9d2f166fbf
@ -11,28 +11,23 @@ from pdf2img.default_objects.image import ImagePlus, ImageInfo
|
||||
from pdf2img.default_objects.rectangle import RectanglePlus
|
||||
|
||||
|
||||
def get_analysis_pipeline(operation, table_parsing_skip_pages_without_images):
    """Build the page-analysis pipeline for *operation*.

    Args:
        operation: Either ``"table"`` or ``"figure"``.
        table_parsing_skip_pages_without_images: Forwarded to the table
            pipeline as ``skip_pages_without_images``.

    Returns:
        The callable produced by ``make_analysis_pipeline``.

    Raises:
        ValueError: If *operation* is not a known operation.
    """
    if operation == "table":
        return make_analysis_pipeline(
            parse_tables,
            table_parsing_formatter,
            dpi=200,
            skip_pages_without_images=table_parsing_skip_pages_without_images,
        )
    elif operation == "figure":
        return make_analysis_pipeline(detect_figures, figure_detection_formatter, dpi=200)
    # Previously an unknown operation fell through and silently returned None;
    # fail loudly instead, matching make_analysis_pipeline_for_element_type.
    raise ValueError(f"Unknown operation {operation}.")
|
||||
def make_analysis_pipeline_for_element_type(segment_type, **kwargs):
    """Return an analysis pipeline configured for *segment_type*.

    Args:
        segment_type: Either ``"table"`` or ``"figure"``.
        **kwargs: Extra keyword arguments forwarded to ``make_analysis_pipeline``
            (e.g. ``skip_pages_without_images``).

    Raises:
        ValueError: If *segment_type* is not a known segment type.
    """
    # Diff residue resolved here: the span contained a leftover bare `raise`
    # next to the intended `raise ValueError(...)`.
    if segment_type == "table":
        analysis_fn, formatter = parse_tables, table_parsing_formatter
    elif segment_type == "figure":
        analysis_fn, formatter = detect_figures, figure_detection_formatter
    else:
        raise ValueError(f"Unknown segment type {segment_type}.")
    return make_analysis_pipeline(analysis_fn, formatter, dpi=200, **kwargs)
|
||||
|
||||
|
||||
def make_analysis_pipeline(analysis_fn, formatter, dpi, skip_pages_without_images=False):
|
||||
def analyse_pipeline(pdf: bytes, index=None):
|
||||
def analysis_pipeline(pdf: bytes, index=None):
|
||||
def parse_page(page: ImagePlus):
|
||||
image = page.asarray()
|
||||
rects = analysis_fn(image)
|
||||
if not rects:
|
||||
rectangles = analysis_fn(image)
|
||||
if not rectangles:
|
||||
return
|
||||
infos = formatter(rects, page, dpi)
|
||||
infos = formatter(rectangles, page, dpi)
|
||||
return infos
|
||||
|
||||
pages = convert_pages_to_images(pdf, index=index, dpi=dpi, skip_pages_without_images=skip_pages_without_images)
|
||||
@ -40,22 +35,26 @@ def make_analysis_pipeline(analysis_fn, formatter, dpi, skip_pages_without_image
|
||||
|
||||
yield from flatten(filter(truth, results))
|
||||
|
||||
return analyse_pipeline
|
||||
return analysis_pipeline
|
||||
|
||||
|
||||
def table_parsing_formatter(rectangles, page: ImagePlus, dpi):
    """Format detected table-cell rectangles for one page.

    Args:
        rectangles: Iterable of ``Rectangle`` cell candidates in pixel space.
        page: The rendered page the rectangles were detected on.
        dpi: DPI the page was rendered at (needed to map pixels back to page space).

    Returns:
        Dict with the page metadata under ``"pageInfo"`` and the serialised
        cells under ``"tableCells"``.
    """
    # Diff residue resolved: the span contained both the pre- and
    # post-rename (rects/rectangles) versions of this function.

    def format_rectangle(rectangle: Rectangle):
        rectangle_plus = RectanglePlus.from_pixels(*rectangle_to_xyxy(rectangle), page.info, alpha=False, dpi=dpi)
        return rectangle_plus.asdict(derotate=True)

    bboxes = lmap(format_rectangle, rectangles)

    return {"pageInfo": page.asdict(natural_index=True), "tableCells": bboxes}
|
||||
|
||||
|
||||
def figure_detection_formatter(rectangles, page, dpi):
    """Serialise detected figure rectangles for one page.

    Args:
        rectangles: Iterable of ``Rectangle`` figure candidates in pixel space.
        page: The rendered page (``ImagePlus``) the rectangles come from.
        dpi: DPI the page was rendered at.

    Returns:
        List of dicts, one ``ImageInfo`` per detected figure.
    """
    # Diff residue resolved: kept the post-rename version. Local renamed
    # rect_plus -> rectangle_plus for consistency with table_parsing_formatter.

    def format_rectangle(rectangle: Rectangle):
        rectangle_plus = RectanglePlus.from_pixels(*rectangle_to_xyxy(rectangle), page.info, alpha=False, dpi=dpi)
        return asdict(ImageInfo(page.info, rectangle_plus.asbbox(derotate=False), rectangle_plus.alpha))

    return lmap(format_rectangle, rectangles)
|
||||
|
||||
|
||||
def rectangle_to_xyxy(rectangle: Rectangle):
    """Return *rectangle*'s corners as an ``(x1, y1, x2, y2)`` tuple."""
    return rectangle.x1, rectangle.y1, rectangle.x2, rectangle.y2
|
||||
|
||||
@ -48,38 +48,17 @@ class Rectangle:
|
||||
def h(self):
    """Height of the rectangle (backed by the name-mangled ``__h`` attribute)."""
    # NOTE(review): likely decorated with @property — the decorator line is
    # outside this view; confirm before relying on attribute-style access.
    return self.__h
|
||||
|
||||
def __str__(self):
    """JSON representation of this rectangle, via ``to_dict``."""
    return dumps(self.to_dict())
|
||||
|
||||
def __repr__(self):
    """Reuse ``__str__`` so debug output matches the JSON form."""
    return str(self)
|
||||
|
||||
def __iter__(self):
    """Iterate over x, y, width, height — the ``to_dict`` values, in order."""
    return list(self.to_dict().values()).__iter__()
|
||||
|
||||
def __eq__(self, other: "Rectangle"):
    """Rectangles compare equal when origin and size match.

    Returns ``NotImplemented`` for non-Rectangle operands so Python can fall
    back to the other operand's ``__eq__`` (previously this raised
    AttributeError when compared against a foreign type).
    """
    if not isinstance(other, type(self)):
        return NotImplemented
    return (self.x1, self.y1, self.w, self.h) == (other.x1, other.y1, other.w, other.h)
|
||||
|
||||
def __hash__(self):
    """Hash on the corner tuple ``(x1, y1, x2, y2)``."""
    # NOTE(review): __eq__ compares (x1, y1, w, h). This stays consistent as
    # long as w/h are derived from the corners (from_xywh suggests w = x2 - x1)
    # — confirm against the attribute definitions above this view.
    return hash((self.x1, self.y1, self.x2, self.y2))
|
||||
|
||||
@classmethod
def from_xywh(cls, xywh: Iterable[Coord], discrete=True):
    """Creates a rectangle from a point, width and height.

    Args:
        xywh: Iterable of four coordinates ``(x1, y1, w, h)``.
        discrete: Passed through to the constructor.

    Returns:
        A new rectangle whose far corner is ``(x1 + w, y1 + h)``.
    """
    # Diff residue resolved: the span carried both the old and new docstring.
    x1, y1, w, h = xywh
    x2 = x1 + w
    y2 = y1 + h
    return cls(x1, y1, x2, y2, discrete=discrete)
|
||||
|
||||
def xyxy(self):
    """Return the corner representation ``(x1, y1, x2, y2)``."""
    return self.x1, self.y1, self.x2, self.y2
|
||||
|
||||
def xywh(self):
    """Return the origin-plus-size representation ``(x1, y1, w, h)``."""
    return self.x1, self.y1, self.w, self.h
|
||||
|
||||
def to_dict(self):
    """Map this rectangle to its serialisable form: x, y, width, height."""
    return dict(x=self.x1, y=self.y1, width=self.w, height=self.h)
|
||||
|
||||
def intersection(self, other):
    """Calculates the intersection of this and another rectangle."""
    # Delegates to the module-level intersection() helper (defined outside
    # this view).
    return intersection(self, other)
|
||||
|
||||
@ -2,28 +2,27 @@ import argparse
|
||||
import json
|
||||
from pathlib import Path
|
||||
|
||||
from cv_analysis.server.pipeline import get_analysis_pipeline
|
||||
from loguru import logger
|
||||
|
||||
from cv_analysis.server.pipeline import make_analysis_pipeline_for_element_type
|
||||
|
||||
|
||||
def parse_args():
    """Parse CLI arguments: a PDF path and the element type to analyse.

    Returns:
        argparse.Namespace with ``pdf`` (a ``pathlib.Path``) and
        ``element_type`` (``"table"`` or ``"figure"``).
    """
    # Diff residue resolved: the span contained both the old arguments
    # ("pdf" as str, --type with a "layout" choice) and the new ones.
    parser = argparse.ArgumentParser()
    parser.add_argument("pdf", type=Path)
    parser.add_argument("--element_type", "-t", choices=["table", "figure"], required=True)
    return parser.parse_args()
|
||||
|
||||
|
||||
def main(args):
    """Run the selected analysis over the given PDF and print the JSON result.

    Args:
        args: Parsed CLI namespace with ``element_type`` and ``pdf`` (a Path).
    """
    pipeline = make_analysis_pipeline_for_element_type(args.element_type)

    logger.info(f"Analysing document for {args.element_type}s...")
    analysis_results = list(pipeline(args.pdf.read_bytes()))

    print(json.dumps(analysis_results, indent=2))
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Diff residue resolved: the old inline script body (manual file reads and
    # JSON dumping to a sibling file) was replaced by the main() entry point.
    main(parse_args())
|
||||
|
||||
@ -4,7 +4,7 @@ import logging
|
||||
from operator import itemgetter
|
||||
|
||||
from cv_analysis.config import get_config
|
||||
from cv_analysis.server.pipeline import get_analysis_pipeline
|
||||
from cv_analysis.server.pipeline import make_analysis_pipeline_for_segment_type
|
||||
from cv_analysis.utils.banner import make_art
|
||||
from pyinfra import config as pyinfra_config
|
||||
from pyinfra.queue.queue_manager import QueueManager
|
||||
@ -31,7 +31,10 @@ def analysis_callback(queue_message: dict):
|
||||
should_publish_result = True
|
||||
|
||||
object_bytes = gzip.decompress(storage.get_object(bucket, object_name))
|
||||
analysis_fn = get_analysis_pipeline(operation, CV_CONFIG.table_parsing_skip_pages_without_images)
|
||||
analysis_fn = make_analysis_pipeline_for_segment_type(
|
||||
operation,
|
||||
skip_pages_without_images=CV_CONFIG.table_parsing_skip_pages_without_images,
|
||||
)
|
||||
|
||||
results = analysis_fn(object_bytes)
|
||||
response = {**queue_message, "data": list(results)}
|
||||
|
||||
@ -2,8 +2,11 @@ from itertools import starmap
|
||||
|
||||
import cv2
|
||||
import pytest
|
||||
from funcy import lmap, compose
|
||||
|
||||
from cv_analysis.table_parsing import parse_tables
|
||||
from cv_analysis.utils import lift
|
||||
from cv_analysis.utils.structures import Rectangle
|
||||
from cv_analysis.utils.test_metrics import compute_document_score
|
||||
|
||||
|
||||
@ -12,8 +15,9 @@ from cv_analysis.utils.test_metrics import compute_document_score
|
||||
def test_table_parsing_on_client_pages(
|
||||
score_threshold, client_page_with_table, expected_table_annotation, test_file_index
|
||||
):
|
||||
result = [rect.to_dict() for rect in parse_tables(client_page_with_table)]
|
||||
formatted_result = {"pages": [{"page": str(test_file_index), "cells": result}]}
|
||||
|
||||
results = compose(lift(rectangle_to_dict), parse_tables)(client_page_with_table)
|
||||
formatted_result = {"pages": [{"cells": results}]}
|
||||
|
||||
score = compute_document_score(formatted_result, expected_table_annotation)
|
||||
|
||||
@ -25,6 +29,14 @@ def error_tolerance(line_thickness):
|
||||
return line_thickness * 7
|
||||
|
||||
|
||||
def rectangle_to_dict(rectangle: Rectangle):
    """Serialise *rectangle* into the annotation-style cell dict."""
    return dict(x=rectangle.x1, y=rectangle.y1, width=rectangle.w, height=rectangle.h)
|
||||
|
||||
|
||||
def rectangle_to_xywh(rectangle: Rectangle):
    """Return *rectangle* as an ``(x, y, width, height)`` tuple."""
    origin = (rectangle.x1, rectangle.y1)
    size = (rectangle.w, rectangle.h)
    return origin + size
|
||||
|
||||
|
||||
@pytest.mark.parametrize("line_thickness", [1, 2, 3])
|
||||
@pytest.mark.parametrize("line_type", [cv2.LINE_4, cv2.LINE_AA, cv2.LINE_8])
|
||||
@pytest.mark.parametrize("table_style", ["closed horizontal vertical", "open horizontal vertical"])
|
||||
@ -32,7 +44,7 @@ def error_tolerance(line_thickness):
|
||||
@pytest.mark.parametrize("background_color", [255, 220])
|
||||
@pytest.mark.parametrize("table_shape", [(5, 8)])
|
||||
def test_table_parsing_on_generic_pages(page_with_table, expected_gold_page_with_table, error_tolerance):
|
||||
result = [x.xywh() for x in parse_tables(page_with_table)]
|
||||
result = lmap(rectangle_to_xywh, parse_tables(page_with_table))
|
||||
assert (
|
||||
result == expected_gold_page_with_table
|
||||
or average_error(result, expected_gold_page_with_table) <= error_tolerance
|
||||
@ -46,8 +58,8 @@ def test_table_parsing_on_generic_pages(page_with_table, expected_gold_page_with
|
||||
@pytest.mark.parametrize("background_color", [255, 220])
|
||||
@pytest.mark.parametrize("table_shape", [(5, 8)])
|
||||
@pytest.mark.xfail
|
||||
def test_bad_qual_table(page_with_patchy_table, expected_gold_page_with_table, error_tolerance):
|
||||
result = [x.xywh() for x in parse_tables(page_with_patchy_table)]
|
||||
def test_low_quality_table(page_with_patchy_table, expected_gold_page_with_table, error_tolerance):
|
||||
result = lmap(rectangle_to_xywh, parse_tables(page_with_patchy_table))
|
||||
assert (
|
||||
result == expected_gold_page_with_table
|
||||
or average_error(result, expected_gold_page_with_table) <= error_tolerance
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user