refactoring
This commit is contained in:
parent
4756b8c9bd
commit
2c908162f1
@ -36,6 +36,13 @@ def get_page_metadata(page):
|
||||
Info.PAGE_IDX: page.number,
|
||||
}
|
||||
|
||||
|
||||
def extract_pages(doc, page_range):
|
||||
page_range = range(page_range.start + 1, page_range.stop + 1)
|
||||
pages = map(doc.load_page, page_range)
|
||||
return pages
|
||||
|
||||
|
||||
class ParsablePDFImageExtractor(ImageExtractor):
|
||||
def __init__(self, verbose=False):
|
||||
self.doc: fitz.fitz.Document = None
|
||||
@ -44,10 +51,7 @@ class ParsablePDFImageExtractor(ImageExtractor):
|
||||
def __process_images_on_page(self, page: fitz.fitz.Page):
|
||||
def load_image_from_xref(xref):
|
||||
maybe_image = self.doc.extract_image(xref)
|
||||
if maybe_image:
|
||||
return Image.open(io.BytesIO(maybe_image["image"]))
|
||||
else:
|
||||
return None
|
||||
return Image.open(io.BytesIO(maybe_image["image"])) if maybe_image else None
|
||||
|
||||
image_infos = page.get_image_info(xrefs=True)
|
||||
xrefs = map(itemgetter("xref"), image_infos)
|
||||
@ -60,12 +64,7 @@ class ParsablePDFImageExtractor(ImageExtractor):
|
||||
def extract(self, pdf: bytes, page_range: range = None):
|
||||
self.doc = fitz.Document(stream=pdf)
|
||||
|
||||
if page_range:
|
||||
page_range = range(page_range.start + 1, page_range.stop + 1)
|
||||
doc = fitz.Document(stream=pdf)
|
||||
pages = map(doc.load_page, page_range)
|
||||
else:
|
||||
pages = self.doc
|
||||
pages = extract_pages(self.doc, page_range) if page_range else self.doc
|
||||
|
||||
image_metadata_pairs = chain.from_iterable(
|
||||
map(
|
||||
|
||||
@ -1,8 +1,12 @@
|
||||
import random
|
||||
|
||||
import fitz
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from image_prediction.estimator.preprocessor.utils import images_to_batch_tensor
|
||||
from image_prediction.extraction import extract_images_from_pdf
|
||||
from image_prediction.image_extractor.extractors.parsable import extract_pages
|
||||
|
||||
|
||||
@pytest.mark.parametrize("extractor_type", ["mock"])
|
||||
@ -18,3 +22,18 @@ def test_parsable_pdf_image_extractor(image_extractor, pdf, images, metadata, in
|
||||
images_extracted, metadata_extracted = map(list, extract_images_from_pdf(pdf, image_extractor))
|
||||
assert np.allclose(images_to_batch_tensor(images_extracted), images_to_batch_tensor(images))
|
||||
assert list(metadata_extracted) == metadata
|
||||
|
||||
|
||||
@pytest.mark.parametrize("batch_size", [1, 2, 16])
|
||||
def test_extract_pages(pdf):
|
||||
doc = fitz.Document(stream=pdf)
|
||||
|
||||
max_index = max(0, doc.page_count - 1)
|
||||
i = random.randint(0, max(0, max_index - 1))
|
||||
j = random.randint(i + 1, max_index) if max_index > 0 else 0
|
||||
|
||||
page_range = range(i, j)
|
||||
|
||||
pages = list(extract_pages(doc, page_range))
|
||||
assert all((isinstance(p, fitz.Page) for p in pages))
|
||||
assert len(pages) == len(page_range)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user