Also sets the default image-stitching tolerance to one pixel and adds an informative log message showing which settings are loaded when the image classification pipeline is initialized.
57 lines
1.7 KiB
Python
57 lines
1.7 KiB
Python
import argparse
|
|
import json
|
|
import os
|
|
from glob import glob
|
|
|
|
from image_prediction.config import CONFIG
|
|
from image_prediction.pipeline import load_pipeline
|
|
from image_prediction.utils import get_logger
|
|
from image_prediction.utils.pdf_annotation import annotate_pdf
|
|
|
|
# Module-level logger for this script, obtained from the project's get_logger() helper.
logger = get_logger()
|
|
|
|
|
|
def parse_args(argv=None):
    """Parse the command-line arguments for the PDF annotation script.

    Args:
        argv: Argument list to parse. ``None`` (the default) makes argparse read
            ``sys.argv[1:]``, i.e. the original behaviour; passing an explicit
            list lets tests and other callers drive the parser directly.

    Returns:
        argparse.Namespace with ``input`` (str), ``print`` (bool) and
        ``page_interval`` (a two-element list of ints, or ``None``).

    Raises:
        SystemExit: raised by argparse on invalid arguments, including a page
            interval whose start is negative or whose end is not greater than
            its start.
    """
    parser = argparse.ArgumentParser()

    parser.add_argument("input", help="pdf file or directory")
    parser.add_argument("--print", "-p", help="print output to terminal", action="store_true", default=False)
    parser.add_argument("--page_interval", "-i", help="page interval [i, j), min index = 0", nargs=2, type=int)

    args = parser.parse_args(argv)

    # The interval is half-open [start, end) and later becomes range(start, end);
    # fail loudly here instead of silently producing an empty/negative page range.
    if args.page_interval is not None:
        start, end = args.page_interval
        if start < 0 or end <= start:
            parser.error(f"invalid page interval [{start}, {end}): require 0 <= start < end")

    return args
|
|
|
|
|
|
def process_pdf(pipeline, pdf_path, page_range=None, output_dir="/tmp"):
    """Run the prediction pipeline on one PDF and write an annotated copy.

    Args:
        pipeline: Callable invoked as ``pipeline(pdf_bytes, page_range=...)``;
            its iterable result is materialised into a list.
        pdf_path: Path of the input PDF.
        page_range: Forwarded unchanged to ``pipeline`` (``None`` by default).
        output_dir: Directory that receives ``<stem>_annotated.pdf``. Defaults
            to ``"/tmp"``, the location that was previously hard-coded.

    Returns:
        The list of predictions produced by ``pipeline`` for this PDF.
    """
    with open(pdf_path, "rb") as f:
        logger.info(f"Processing {pdf_path}")
        pdf_bytes = f.read()

    predictions = list(pipeline(pdf_bytes, page_range=page_range))

    # Build the output name from the basename's stem. The previous
    # str.replace(".pdf", "_annotated.pdf") rewrote every ".pdf" occurrence and
    # left ".PDF"/extension-less names unchanged, so an input already located in
    # the output directory could be overwritten by its own annotation.
    stem, _ = os.path.splitext(os.path.basename(pdf_path))
    output_path = os.path.join(output_dir, f"{stem}_annotated.pdf")

    annotate_pdf(pdf_path, predictions, output_path)

    return predictions
|
|
|
|
|
|
def main(args):
    """Annotate every PDF selected by ``args.input`` and optionally print predictions.

    Args:
        args: Namespace as produced by ``parse_args`` (``input``, ``print``,
            ``page_interval``). ``input`` may be a single PDF file or a
            directory whose top-level ``*.pdf`` files are processed.

    Raises:
        FileNotFoundError: if ``args.input`` is neither an existing file nor an
            existing directory.
    """
    # Resolve and validate the inputs before loading the (potentially costly) pipeline.
    if os.path.isfile(args.input):
        pdf_paths = [args.input]
    elif os.path.isdir(args.input):
        # glob() yields entries in arbitrary order; sort for reproducible runs.
        pdf_paths = sorted(glob(os.path.join(args.input, "*.pdf")))
    else:
        raise FileNotFoundError(f"Input is neither a file nor a directory: {args.input}")

    if not pdf_paths:
        logger.warning(f"No PDF files found in {args.input}")
        return

    page_range = range(*args.page_interval) if args.page_interval else None

    pipeline = load_pipeline(
        verbose=CONFIG.service.verbose,
        batch_size=CONFIG.service.batch_size,
        # NOTE(review): "stiching" spelling kept as-is; it must match the config
        # attribute name — confirm against the config schema.
        tolerance=CONFIG.service.image_stiching_tolerance,
    )

    for pdf_path in pdf_paths:
        predictions = process_pdf(pipeline, pdf_path, page_range=page_range)
        if args.print:
            print(pdf_path)
            print(json.dumps(predictions, indent=2))
|
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: parse the command line and hand the namespace straight to main().
    main(parse_args())
|