From e0885c545aeb96d2eef9cf9cfa540fe457784c44 Mon Sep 17 00:00:00 2001 From: Matthias Bisping Date: Tue, 5 Apr 2022 13:03:17 +0200 Subject: [PATCH] added page range paramter to extractor --- .../extractor_classifier/extractor_classifier.py | 4 ++-- image_prediction/image_extractor/extractor.py | 4 ++-- image_prediction/pipeline.py | 4 ++-- scripts/run_pipeline.py | 11 ++++++----- 4 files changed, 12 insertions(+), 11 deletions(-) diff --git a/image_prediction/extractor_classifier/extractor_classifier.py b/image_prediction/extractor_classifier/extractor_classifier.py index b08330a..95d217b 100644 --- a/image_prediction/extractor_classifier/extractor_classifier.py +++ b/image_prediction/extractor_classifier/extractor_classifier.py @@ -21,8 +21,8 @@ class ExtractorClassifier: responses = ({"classification": prd, **mdt} for prd, mdt in zip(predictions, metadata)) return responses - def __call__(self, obj) -> Iterable[ImageMetadataPair]: - image_metadata_pairs = self.extractor(obj) + def __call__(self, obj, **kwargs) -> Iterable[ImageMetadataPair]: + image_metadata_pairs = self.extractor(obj, **kwargs) batches = chunk_iterable(image_metadata_pairs, chunk_size=16) predictions = chain.from_iterable(map(self.__process_batch, batches)) return predictions diff --git a/image_prediction/image_extractor/extractor.py b/image_prediction/image_extractor/extractor.py index 8f1bfe6..ca6392e 100644 --- a/image_prediction/image_extractor/extractor.py +++ b/image_prediction/image_extractor/extractor.py @@ -14,6 +14,6 @@ class ImageExtractor(abc.ABC): def extract(self, obj) -> Iterable[ImageMetadataPair]: raise NotImplementedError - def __call__(self, obj): + def __call__(self, obj, **kwargs): logger.debug("ImageExtractor.extract") - return self.extract(obj) + return self.extract(obj, **kwargs) diff --git a/image_prediction/pipeline.py b/image_prediction/pipeline.py index a58119b..35ce8af 100644 --- a/image_prediction/pipeline.py +++ b/image_prediction/pipeline.py @@ -22,5 +22,5 @@ class Pipeline: def __init__(self, model_loader, model_identifier, **kwargs): self.pipe = rcompose(get_extractor_classifier(model_loader, model_identifier, **kwargs), get_formatter()) - def __call__(self, pdf: bytes): - yield from self.pipe(pdf) + def __call__(self, pdf: bytes, page_range: range = None): + yield from self.pipe(pdf, page_range=page_range) diff --git a/scripts/run_pipeline.py b/scripts/run_pipeline.py index 81fa140..a44c572 100644 --- a/scripts/run_pipeline.py +++ b/scripts/run_pipeline.py @@ -14,18 +14,18 @@ def parse_args(): parser = argparse.ArgumentParser() parser.add_argument("input", help="pdf file or directory") - - parser.add_argument("-print", "-p", help="print output to terminal", action="store_true", default=False) + parser.add_argument("--print", "-p", help="print output to terminal", action="store_true", default=False) + parser.add_argument("--page_interval", "-i", help="page interval [i, j), min index = 0", nargs=2, type=int) args = parser.parse_args() return args -def process_pdf(pipeline, pdf_path): +def process_pdf(pipeline, pdf_path, page_range=None): with open(pdf_path, "rb") as f: logger.info(f"Processing {pdf_path}") - predictions = list(pipeline(f.read())) + predictions = list(pipeline(f.read(), page_range=page_range)) annotate_pdf( pdf_path, predictions, os.path.join("/tmp", os.path.basename(pdf_path.replace(".pdf", "_annotated.pdf"))) @@ -41,9 +41,10 @@ def main(args): pdf_paths = [args.input] else: pdf_paths = glob(os.path.join(args.input, "*.pdf")) + page_range = range(*args.page_interval) if args.page_interval else None for pdf_path in pdf_paths: - predictions = process_pdf(pipeline, pdf_path) + predictions = process_pdf(pipeline, pdf_path, page_range=page_range) if args.print: print(pdf_path) print(json.dumps(predictions, indent=2))