added page range parameter to extractor

This commit is contained in:
Matthias Bisping 2022-04-05 13:03:17 +02:00
parent fdb7ebe618
commit e0885c545a
4 changed files with 12 additions and 11 deletions

View File

@ -21,8 +21,8 @@ class ExtractorClassifier:
responses = ({"classification": prd, **mdt} for prd, mdt in zip(predictions, metadata)) responses = ({"classification": prd, **mdt} for prd, mdt in zip(predictions, metadata))
return responses return responses
def __call__(self, obj) -> Iterable[ImageMetadataPair]: def __call__(self, obj, **kwargs) -> Iterable[ImageMetadataPair]:
image_metadata_pairs = self.extractor(obj) image_metadata_pairs = self.extractor(obj, **kwargs)
batches = chunk_iterable(image_metadata_pairs, chunk_size=16) batches = chunk_iterable(image_metadata_pairs, chunk_size=16)
predictions = chain.from_iterable(map(self.__process_batch, batches)) predictions = chain.from_iterable(map(self.__process_batch, batches))
return predictions return predictions

View File

@ -14,6 +14,6 @@ class ImageExtractor(abc.ABC):
def extract(self, obj) -> Iterable[ImageMetadataPair]: def extract(self, obj) -> Iterable[ImageMetadataPair]:
raise NotImplementedError raise NotImplementedError
def __call__(self, obj): def __call__(self, obj, **kwargs):
logger.debug("ImageExtractor.extract") logger.debug("ImageExtractor.extract")
return self.extract(obj) return self.extract(obj, **kwargs)

View File

@ -22,5 +22,5 @@ class Pipeline:
def __init__(self, model_loader, model_identifier, **kwargs): def __init__(self, model_loader, model_identifier, **kwargs):
self.pipe = rcompose(get_extractor_classifier(model_loader, model_identifier, **kwargs), get_formatter()) self.pipe = rcompose(get_extractor_classifier(model_loader, model_identifier, **kwargs), get_formatter())
def __call__(self, pdf: bytes): def __call__(self, pdf: bytes, page_range: range = None):
yield from self.pipe(pdf) yield from self.pipe(pdf, page_range=page_range)

View File

@ -14,18 +14,18 @@ def parse_args():
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument("input", help="pdf file or directory") parser.add_argument("input", help="pdf file or directory")
parser.add_argument("--print", "-p", help="print output to terminal", action="store_true", default=False)
parser.add_argument("-print", "-p", help="print output to terminal", action="store_true", default=False) parser.add_argument("--page_interval", "-i", help="page interval [i, j), min index = 0", nargs=2, type=int)
args = parser.parse_args() args = parser.parse_args()
return args return args
def process_pdf(pipeline, pdf_path): def process_pdf(pipeline, pdf_path, page_range=None):
with open(pdf_path, "rb") as f: with open(pdf_path, "rb") as f:
logger.info(f"Processing {pdf_path}") logger.info(f"Processing {pdf_path}")
predictions = list(pipeline(f.read())) predictions = list(pipeline(f.read(), page_range=page_range))
annotate_pdf( annotate_pdf(
pdf_path, predictions, os.path.join("/tmp", os.path.basename(pdf_path.replace(".pdf", "_annotated.pdf"))) pdf_path, predictions, os.path.join("/tmp", os.path.basename(pdf_path.replace(".pdf", "_annotated.pdf")))
@ -41,9 +41,10 @@ def main(args):
pdf_paths = [args.input] pdf_paths = [args.input]
else: else:
pdf_paths = glob(os.path.join(args.input, "*.pdf")) pdf_paths = glob(os.path.join(args.input, "*.pdf"))
page_range = range(*args.page_interval) if args.page_interval else None
for pdf_path in pdf_paths: for pdf_path in pdf_paths:
predictions = process_pdf(pipeline, pdf_path) predictions = process_pdf(pipeline, pdf_path, page_range=page_range)
if args.print: if args.print:
print(pdf_path) print(pdf_path)
print(json.dumps(predictions, indent=2)) print(json.dumps(predictions, indent=2))