refactoring

This commit is contained in:
Matthias Bisping 2022-04-05 16:31:57 +02:00
parent 4756b8c9bd
commit 2c908162f1
2 changed files with 28 additions and 10 deletions

View File

@ -36,6 +36,13 @@ def get_page_metadata(page):
Info.PAGE_IDX: page.number, Info.PAGE_IDX: page.number,
} }
def extract_pages(doc, page_range):
page_range = range(page_range.start + 1, page_range.stop + 1)
pages = map(doc.load_page, page_range)
return pages
class ParsablePDFImageExtractor(ImageExtractor): class ParsablePDFImageExtractor(ImageExtractor):
def __init__(self, verbose=False): def __init__(self, verbose=False):
self.doc: fitz.fitz.Document = None self.doc: fitz.fitz.Document = None
@ -44,10 +51,7 @@ class ParsablePDFImageExtractor(ImageExtractor):
def __process_images_on_page(self, page: fitz.fitz.Page): def __process_images_on_page(self, page: fitz.fitz.Page):
def load_image_from_xref(xref): def load_image_from_xref(xref):
maybe_image = self.doc.extract_image(xref) maybe_image = self.doc.extract_image(xref)
if maybe_image: return Image.open(io.BytesIO(maybe_image["image"])) if maybe_image else None
return Image.open(io.BytesIO(maybe_image["image"]))
else:
return None
image_infos = page.get_image_info(xrefs=True) image_infos = page.get_image_info(xrefs=True)
xrefs = map(itemgetter("xref"), image_infos) xrefs = map(itemgetter("xref"), image_infos)
@ -60,12 +64,7 @@ class ParsablePDFImageExtractor(ImageExtractor):
def extract(self, pdf: bytes, page_range: range = None): def extract(self, pdf: bytes, page_range: range = None):
self.doc = fitz.Document(stream=pdf) self.doc = fitz.Document(stream=pdf)
if page_range: pages = extract_pages(self.doc, page_range) if page_range else self.doc
page_range = range(page_range.start + 1, page_range.stop + 1)
doc = fitz.Document(stream=pdf)
pages = map(doc.load_page, page_range)
else:
pages = self.doc
image_metadata_pairs = chain.from_iterable( image_metadata_pairs = chain.from_iterable(
map( map(

View File

@ -1,8 +1,12 @@
import random
import fitz
import numpy as np import numpy as np
import pytest import pytest
from image_prediction.estimator.preprocessor.utils import images_to_batch_tensor from image_prediction.estimator.preprocessor.utils import images_to_batch_tensor
from image_prediction.extraction import extract_images_from_pdf from image_prediction.extraction import extract_images_from_pdf
from image_prediction.image_extractor.extractors.parsable import extract_pages
@pytest.mark.parametrize("extractor_type", ["mock"]) @pytest.mark.parametrize("extractor_type", ["mock"])
@ -18,3 +22,18 @@ def test_parsable_pdf_image_extractor(image_extractor, pdf, images, metadata, in
images_extracted, metadata_extracted = map(list, extract_images_from_pdf(pdf, image_extractor)) images_extracted, metadata_extracted = map(list, extract_images_from_pdf(pdf, image_extractor))
assert np.allclose(images_to_batch_tensor(images_extracted), images_to_batch_tensor(images)) assert np.allclose(images_to_batch_tensor(images_extracted), images_to_batch_tensor(images))
assert list(metadata_extracted) == metadata assert list(metadata_extracted) == metadata
@pytest.mark.parametrize("batch_size", [1, 2, 16])
def test_extract_pages(pdf):
doc = fitz.Document(stream=pdf)
max_index = max(0, doc.page_count - 1)
i = random.randint(0, max(0, max_index - 1))
j = random.randint(i + 1, max_index) if max_index > 0 else 0
page_range = range(i, j)
pages = list(extract_pages(doc, page_range))
assert all((isinstance(p, fitz.Page) for p in pages))
assert len(pages) == len(page_range)