From 2c908162f1e6fc96fd212074e3b50a6f1c035403 Mon Sep 17 00:00:00 2001 From: Matthias Bisping Date: Tue, 5 Apr 2022 16:31:57 +0200 Subject: [PATCH] refactoring --- .../image_extractor/extractors/parsable.py | 19 +++++++++---------- test/unit_tests/image_extractor_test.py | 19 +++++++++++++++++++ 2 files changed, 28 insertions(+), 10 deletions(-) diff --git a/image_prediction/image_extractor/extractors/parsable.py b/image_prediction/image_extractor/extractors/parsable.py index b70b0d9..9c2e10a 100644 --- a/image_prediction/image_extractor/extractors/parsable.py +++ b/image_prediction/image_extractor/extractors/parsable.py @@ -36,6 +36,13 @@ def get_page_metadata(page): Info.PAGE_IDX: page.number, } + +def extract_pages(doc, page_range): + page_range = range(page_range.start + 1, page_range.stop + 1) + pages = map(doc.load_page, page_range) + return pages + + class ParsablePDFImageExtractor(ImageExtractor): def __init__(self, verbose=False): self.doc: fitz.fitz.Document = None @@ -44,10 +51,7 @@ class ParsablePDFImageExtractor(ImageExtractor): def __process_images_on_page(self, page: fitz.fitz.Page): def load_image_from_xref(xref): maybe_image = self.doc.extract_image(xref) - if maybe_image: - return Image.open(io.BytesIO(maybe_image["image"])) - else: - return None + return Image.open(io.BytesIO(maybe_image["image"])) if maybe_image else None image_infos = page.get_image_info(xrefs=True) xrefs = map(itemgetter("xref"), image_infos) @@ -60,12 +64,7 @@ class ParsablePDFImageExtractor(ImageExtractor): def extract(self, pdf: bytes, page_range: range = None): self.doc = fitz.Document(stream=pdf) - if page_range: - page_range = range(page_range.start + 1, page_range.stop + 1) - doc = fitz.Document(stream=pdf) - pages = map(doc.load_page, page_range) - else: - pages = self.doc + pages = extract_pages(self.doc, page_range) if page_range else self.doc image_metadata_pairs = chain.from_iterable( map( diff --git a/test/unit_tests/image_extractor_test.py b/test/unit_tests/image_extractor_test.py index a6cdb3a..3c72354 100644 --- a/test/unit_tests/image_extractor_test.py +++ b/test/unit_tests/image_extractor_test.py @@ -1,8 +1,12 @@ +import random + +import fitz import numpy as np import pytest from image_prediction.estimator.preprocessor.utils import images_to_batch_tensor from image_prediction.extraction import extract_images_from_pdf +from image_prediction.image_extractor.extractors.parsable import extract_pages @pytest.mark.parametrize("extractor_type", ["mock"]) @@ -18,3 +22,18 @@ def test_parsable_pdf_image_extractor(image_extractor, pdf, images, metadata, in images_extracted, metadata_extracted = map(list, extract_images_from_pdf(pdf, image_extractor)) assert np.allclose(images_to_batch_tensor(images_extracted), images_to_batch_tensor(images)) assert list(metadata_extracted) == metadata + + +@pytest.mark.parametrize("batch_size", [1, 2, 16]) +def test_extract_pages(pdf): + doc = fitz.Document(stream=pdf) + + max_index = max(0, doc.page_count - 1) + i = random.randint(0, max(0, max_index - 1)) + j = random.randint(i + 1, max_index) if max_index > 0 else 0 + + page_range = range(i, j) + + pages = list(extract_pages(doc, page_range)) + assert all((isinstance(p, fitz.Page) for p in pages)) + assert len(pages) == len(page_range)