From 8f61c4cba2f181c648b2c64c4d9c747b44ddbd8b Mon Sep 17 00:00:00 2001 From: Matthias Bisping Date: Mon, 4 Apr 2022 21:49:45 +0200 Subject: [PATCH] doc.extract_image(xref) can yield None; hence added filtering for None images --- .../image_extractor/extractors/parsable.py | 13 +++++--- scripts/run_pipeline.py | 32 +++++++++++++------ 2 files changed, 32 insertions(+), 13 deletions(-) diff --git a/image_prediction/image_extractor/extractors/parsable.py b/image_prediction/image_extractor/extractors/parsable.py index bab7fd0..0277a5a 100644 --- a/image_prediction/image_extractor/extractors/parsable.py +++ b/image_prediction/image_extractor/extractors/parsable.py @@ -1,10 +1,11 @@ import io from itertools import chain, starmap -from operator import itemgetter +from operator import itemgetter, __and__, truth import fitz from PIL import Image -from funcy import rcompose +from funcy import rcompose, compose, curry +from iteration_utilities import starfilter from tqdm import tqdm from image_prediction.image_extractor.extractor import ImageExtractor, ImageMetadataPair @@ -18,7 +19,11 @@ class ParsablePDFImageExtractor(ImageExtractor): def __process_images_on_page(self, page: fitz.fitz.Page): def load_image_from_xref(xref): - return Image.open(io.BytesIO(self.doc.extract_image(xref)["image"])) + maybe_image = self.doc.extract_image(xref) + if maybe_image: + return Image.open(io.BytesIO(maybe_image["image"])) + else: + return None def format_metadata(image_info): x1, y1, x2, y2 = map(rounder, image_info["bbox"]) @@ -44,7 +49,7 @@ class ParsablePDFImageExtractor(ImageExtractor): images = map(load_image_from_xref, xrefs) metadata = map(format_metadata, image_infos) - return starmap(ImageMetadataPair, zip(images, metadata)) + return starmap(ImageMetadataPair, filter(compose(all, curry(map)(truth)), zip(images, metadata))) def extract(self, pdf: bytes): self.doc = fitz.Document(stream=pdf) diff --git a/scripts/run_pipeline.py b/scripts/run_pipeline.py index f0d3be8..4d74413 100644 --- a/scripts/run_pipeline.py +++ b/scripts/run_pipeline.py @@ -1,35 +1,49 @@ import argparse import json import os +from glob import glob from image_prediction.pipeline import load_pipeline -from image_prediction.transformer.transformers.coordinate.pdfnet import PDFNetCoordinateTransformer from image_prediction.utils.pdf_annotation import annotate_pdf def parse_args(): parser = argparse.ArgumentParser() - parser.add_argument("pdf") + + parser.add_argument("input", help="pdf file or directory") + + parser.add_argument("-print", "-p", help="print output to terminal", action="store_true", default=False) args = parser.parse_args() return args -def main(args): - pipeline = load_pipeline(verbose=True) - - pdf_path = args.pdf - +def process_pdf(pipeline, pdf_path): with open(pdf_path, "rb") as f: predictions = list(pipeline(f.read())) - print(json.dumps(predictions, indent=2)) - annotate_pdf( pdf_path, predictions, os.path.join("/tmp", os.path.basename(pdf_path.replace(".pdf", "_annotated.pdf"))) ) + return predictions + + +def main(args): + pipeline = load_pipeline(verbose=True) + + if os.path.isfile(args.input): + pdf_paths = [args.input] + else: + pdf_paths = glob(os.path.join(args.input, "*.pdf")) + + for pdf_path in pdf_paths: + predictions = process_pdf(pipeline, pdf_path) + if args.print: + print(pdf_path) + print(json.dumps(predictions, indent=2)) + if __name__ == "__main__": args = parse_args()