from dataclasses import asdict
from operator import truth

from funcy import lmap, flatten

from cv_analysis.figure_detection.figure_detection import detect_figures
from cv_analysis.table_parsing import parse_tables
from cv_analysis.utils.rectangle import Rectangle
from pdf2img.conversion import convert_pages_to_images
from pdf2img.default_objects.image import ImagePlus, ImageInfo
from pdf2img.default_objects.rectangle import RectanglePlus

# Resolution used to rasterize PDF pages when the caller does not specify one.
DEFAULT_DPI = 200


def make_analysis_pipeline_for_element_type(segment_type, **kwargs):
    """Build an analysis pipeline for the given kind of page element.

    Args:
        segment_type: ``"table"`` to run ``parse_tables`` with
            ``table_parsing_formatter``, or ``"figure"`` to run
            ``detect_figures`` with ``figure_detection_formatter``.
        **kwargs: Forwarded to ``make_analysis_pipeline``. ``dpi`` defaults to
            ``DEFAULT_DPI`` but may be overridden here.

    Returns:
        The generator function produced by ``make_analysis_pipeline``.

    Raises:
        ValueError: If ``segment_type`` is not ``"table"`` or ``"figure"``.
    """
    if segment_type == "table":
        analysis_fn, formatter = parse_tables, table_parsing_formatter
    elif segment_type == "figure":
        analysis_fn, formatter = detect_figures, figure_detection_formatter
    else:
        raise ValueError(f"Unknown segment type {segment_type}.")

    # Previously dpi was passed explicitly *and* via **kwargs, so a caller
    # supplying dpi got a "multiple values for keyword argument" TypeError.
    kwargs.setdefault("dpi", DEFAULT_DPI)
    return make_analysis_pipeline(analysis_fn, formatter, **kwargs)


def make_analysis_pipeline(analysis_fn, formatter, dpi, skip_pages_without_images=False):
    """Compose page rasterization, image analysis and result formatting.

    Args:
        analysis_fn: Called with each page's ``page.asarray()``; must return a
            collection of rectangles (falsy when nothing was found).
        formatter: Called as ``formatter(rectangles, page, dpi)`` for pages
            where ``analysis_fn`` found something.
        dpi: Resolution used both for rasterizing pages and for converting
            pixel rectangles back via the formatter.
        skip_pages_without_images: Forwarded to ``convert_pages_to_images``.

    Returns:
        A generator function ``analysis_pipeline(pdf, index=None)``.
    """

    def analysis_pipeline(pdf: bytes, index=None):
        """Lazily yield formatted results for the pages of ``pdf``.

        ``index`` is forwarded to ``convert_pages_to_images`` (presumably a
        page selection — confirm against pdf2img).
        """

        def parse_page(page: ImagePlus):
            # Returns None for pages with no detections; filtered out below.
            rectangles = analysis_fn(page.asarray())
            if not rectangles:
                return None
            return formatter(rectangles, page, dpi)

        pages = convert_pages_to_images(
            pdf,
            index=index,
            dpi=dpi,
            skip_pages_without_images=skip_pages_without_images,
        )
        results = map(parse_page, pages)

        # funcy.flatten only descends into lists/tuples/iterators, not dicts:
        # table pages (one dict each) are yielded whole, while figure pages
        # (a list of dicts each) are expanded into individual dicts.
        yield from flatten(filter(truth, results))

    return analysis_pipeline


def _to_rectangle_plus(rectangle: Rectangle, page: ImagePlus, dpi):
    """Convert a pixel-space ``Rectangle`` on ``page`` into a ``RectanglePlus``.

    Shared by both formatters; ``alpha`` is always ``False``.
    """
    return RectanglePlus.from_pixels(*rectangle_to_xyxy(rectangle), page.info, alpha=False, dpi=dpi)


def table_parsing_formatter(rectangles, page: ImagePlus, dpi):
    """Format table cells found on one page.

    Returns:
        ``{"pageInfo": ..., "tableCells": [...]}`` where each cell is the
        derotated ``RectanglePlus.asdict`` of a detected rectangle.
    """

    def format_rectangle(rectangle: Rectangle):
        return _to_rectangle_plus(rectangle, page, dpi).asdict(derotate=True)

    bboxes = lmap(format_rectangle, rectangles)

    return {"pageInfo": page.asdict(natural_index=True), "tableCells": bboxes}


def figure_detection_formatter(rectangles, page: ImagePlus, dpi):
    """Format figures found on one page.

    Returns:
        A list with one ``ImageInfo`` (as a dict) per detected rectangle, using
        the non-derotated bounding box.
    """

    def format_rectangle(rectangle: Rectangle):
        rect_plus = _to_rectangle_plus(rectangle, page, dpi)
        return asdict(ImageInfo(page.info, rect_plus.asbbox(derotate=False), rect_plus.alpha))

    return lmap(format_rectangle, rectangles)


def rectangle_to_xyxy(rectangle: Rectangle):
    """Return ``(x1, y1, x2, y2)`` for ``rectangle``."""
    return rectangle.x1, rectangle.y1, rectangle.x2, rectangle.y2