refactoring
This commit is contained in:
parent
4756b8c9bd
commit
2c908162f1
@ -36,6 +36,13 @@ def get_page_metadata(page):
|
|||||||
Info.PAGE_IDX: page.number,
|
Info.PAGE_IDX: page.number,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def extract_pages(doc, page_range):
|
||||||
|
page_range = range(page_range.start + 1, page_range.stop + 1)
|
||||||
|
pages = map(doc.load_page, page_range)
|
||||||
|
return pages
|
||||||
|
|
||||||
|
|
||||||
class ParsablePDFImageExtractor(ImageExtractor):
|
class ParsablePDFImageExtractor(ImageExtractor):
|
||||||
def __init__(self, verbose=False):
|
def __init__(self, verbose=False):
|
||||||
self.doc: fitz.fitz.Document = None
|
self.doc: fitz.fitz.Document = None
|
||||||
@ -44,10 +51,7 @@ class ParsablePDFImageExtractor(ImageExtractor):
|
|||||||
def __process_images_on_page(self, page: fitz.fitz.Page):
|
def __process_images_on_page(self, page: fitz.fitz.Page):
|
||||||
def load_image_from_xref(xref):
|
def load_image_from_xref(xref):
|
||||||
maybe_image = self.doc.extract_image(xref)
|
maybe_image = self.doc.extract_image(xref)
|
||||||
if maybe_image:
|
return Image.open(io.BytesIO(maybe_image["image"])) if maybe_image else None
|
||||||
return Image.open(io.BytesIO(maybe_image["image"]))
|
|
||||||
else:
|
|
||||||
return None
|
|
||||||
|
|
||||||
image_infos = page.get_image_info(xrefs=True)
|
image_infos = page.get_image_info(xrefs=True)
|
||||||
xrefs = map(itemgetter("xref"), image_infos)
|
xrefs = map(itemgetter("xref"), image_infos)
|
||||||
@ -60,12 +64,7 @@ class ParsablePDFImageExtractor(ImageExtractor):
|
|||||||
def extract(self, pdf: bytes, page_range: range = None):
|
def extract(self, pdf: bytes, page_range: range = None):
|
||||||
self.doc = fitz.Document(stream=pdf)
|
self.doc = fitz.Document(stream=pdf)
|
||||||
|
|
||||||
if page_range:
|
pages = extract_pages(self.doc, page_range) if page_range else self.doc
|
||||||
page_range = range(page_range.start + 1, page_range.stop + 1)
|
|
||||||
doc = fitz.Document(stream=pdf)
|
|
||||||
pages = map(doc.load_page, page_range)
|
|
||||||
else:
|
|
||||||
pages = self.doc
|
|
||||||
|
|
||||||
image_metadata_pairs = chain.from_iterable(
|
image_metadata_pairs = chain.from_iterable(
|
||||||
map(
|
map(
|
||||||
|
|||||||
@ -1,8 +1,12 @@
|
|||||||
|
import random
|
||||||
|
|
||||||
|
import fitz
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from image_prediction.estimator.preprocessor.utils import images_to_batch_tensor
|
from image_prediction.estimator.preprocessor.utils import images_to_batch_tensor
|
||||||
from image_prediction.extraction import extract_images_from_pdf
|
from image_prediction.extraction import extract_images_from_pdf
|
||||||
|
from image_prediction.image_extractor.extractors.parsable import extract_pages
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("extractor_type", ["mock"])
|
@pytest.mark.parametrize("extractor_type", ["mock"])
|
||||||
@ -18,3 +22,18 @@ def test_parsable_pdf_image_extractor(image_extractor, pdf, images, metadata, in
|
|||||||
images_extracted, metadata_extracted = map(list, extract_images_from_pdf(pdf, image_extractor))
|
images_extracted, metadata_extracted = map(list, extract_images_from_pdf(pdf, image_extractor))
|
||||||
assert np.allclose(images_to_batch_tensor(images_extracted), images_to_batch_tensor(images))
|
assert np.allclose(images_to_batch_tensor(images_extracted), images_to_batch_tensor(images))
|
||||||
assert list(metadata_extracted) == metadata
|
assert list(metadata_extracted) == metadata
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("batch_size", [1, 2, 16])
|
||||||
|
def test_extract_pages(pdf):
|
||||||
|
doc = fitz.Document(stream=pdf)
|
||||||
|
|
||||||
|
max_index = max(0, doc.page_count - 1)
|
||||||
|
i = random.randint(0, max(0, max_index - 1))
|
||||||
|
j = random.randint(i + 1, max_index) if max_index > 0 else 0
|
||||||
|
|
||||||
|
page_range = range(i, j)
|
||||||
|
|
||||||
|
pages = list(extract_pages(doc, page_range))
|
||||||
|
assert all((isinstance(p, fitz.Page) for p in pages))
|
||||||
|
assert len(pages) == len(page_range)
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user