add ParsablePDFImageExtractor test

This commit is contained in:
Julius Unverfehrt 2022-03-28 15:42:54 +02:00
parent 2631eb5c0f
commit 9461be29d5
4 changed files with 71 additions and 25 deletions

View File

@ -1,7 +1,8 @@
from itertools import chain from itertools import chain, starmap
from operator import itemgetter from operator import itemgetter
import fitz import fitz
from funcy import rcompose
from image_prediction.image_extractor.extractor import ImageExtractor, ImageMetadataPair from image_prediction.image_extractor.extractor import ImageExtractor, ImageMetadataPair
@ -9,22 +10,37 @@ from image_prediction.image_extractor.extractor import ImageExtractor, ImageMeta
class ParsablePDFImageExtractor(ImageExtractor): class ParsablePDFImageExtractor(ImageExtractor):
def __init__(self): def __init__(self):
self.doc: fitz.fitz.Document = None self.doc: fitz.fitz.Document = None
self.page: fitz.fitz.Page = None
def __build_metadata(self, xref):
metadata = self.page.get_image_info(xref)
page_width, page_height = self.page.mediabox_size
metadata = {**metadata, "page_width": page_width, "page_height": page_height, "page_idx": self.page.number}
return metadata
def __process_images_on_page(self, page: fitz.fitz.Page): def __process_images_on_page(self, page: fitz.fitz.Page):
self.page = page def load_image_from_xref(xref):
return self.doc.extract_image(xref)["image"]
def format_metadata(image_info):
x1, y1, x2, y2 = map(rounder, image_info["bbox"])
width, height = itemgetter("width", "height")(image_info)
return {
"page_width": page_width,
"page_height": page_height,
"page_idx": page.number,
"width": width,
"height": height,
"x1": x1,
"x2": x2,
"y1": y1,
"y2": y2
}
rounder = rcompose(round, int)
page_width, page_height = map(rounder, page.mediabox_size)
image_handles = page.get_images(full=True) image_handles = page.get_images(full=True)
xrefs = itemgetter(0)(image_handles) xrefs = map(itemgetter(0), image_handles)
images = map(lambda xref: self.doc.extract_image(xref)["image"], xrefs) images = map(load_image_from_xref, xrefs)
metadata = map(self.__build_metadata, xrefs) image_infos = page.get_image_info()
return map(ImageMetadataPair, zip(images, metadata)) metadata = map(format_metadata, image_infos)
return starmap(ImageMetadataPair, zip(images, metadata))
def extract(self, pdf: bytes): def extract(self, pdf: bytes):
self.doc = fitz.Document(stream=pdf) self.doc = fitz.Document(stream=pdf)

View File

@ -22,3 +22,4 @@ scikit_learn~=0.24.2
pytest~=7.1.0 pytest~=7.1.0
funcy==1.17 funcy==1.17
PyMuPDF==1.19.6 PyMuPDF==1.19.6
fpdf==1.7.2

View File

@ -1,5 +1,8 @@
import random import random
import tempfile
from operator import itemgetter
import fpdf
import numpy as np import numpy as np
import pytest import pytest
from PIL import Image from PIL import Image
@ -126,19 +129,25 @@ def map_labels(numeric_labels, classes):
@pytest.fixture @pytest.fixture
def metadata(images): def metadata(images):
page_idx = 0
def current_page_idx():
nonlocal page_idx
page_idx += random.randint(0, 2)
return min(page_idx, len(images) - 1)
def build_image_metadata(image): def build_image_metadata(image):
width, height = image.size width, height = image.size
page_width = 595.32 page_width = 595
page_height = 842.04 page_height = 842
x1 = random.uniform(0, page_width - width) x1 = random.randint(0, page_width - width)
x2 = x1 + width x2 = x1 + width
y1 = random.uniform(0, page_height) y1 = random.randint(0, page_height - height)
y2 = y1 - height y2 = y1 + height
metadata = { metadata = {
"page_width": page_width, "page_width": page_width,
"page_height": page_height, "page_height": page_height,
"page_idx": 0, "page_idx": current_page_idx(),
"width": width, "width": width,
"height": height, "height": height,
"x1": x1, "x1": x1,
@ -146,5 +155,23 @@ def metadata(images):
"y1": y1, "y1": y1,
"y2": y2 "y2": y2
} }
return metadata
return map(build_image_metadata, images) return list(map(build_image_metadata, images))
@pytest.fixture
def pdf(images, metadata):
pdf = fpdf.FPDF(unit="pt")
pdf.add_page()
for image, metadata in zip(images, metadata):
while metadata["page_idx"] > pdf.page - 1:
pdf.add_page()
x, y, w, h = itemgetter("x1", "y1", "width", "height")(metadata)
with tempfile.NamedTemporaryFile(suffix=".png") as temp_image:
image.save(temp_image.name)
pdf.image(temp_image.name, x=x, y=y, w=w, h=h)
with tempfile.NamedTemporaryFile(suffix=".pdf") as temp_pdf:
pdf.output(temp_pdf.name)
with open(temp_pdf.name, "rb") as open_temp_pdf:
yield open_temp_pdf.read()

View File

@ -1,3 +1,5 @@
import time
import pytest import pytest
@ -9,8 +11,8 @@ def test_image_extractor_mock(image_extractor, images):
@pytest.mark.parametrize("extractor_type", ["parsable_pdf"]) @pytest.mark.parametrize("extractor_type", ["parsable_pdf"])
@pytest.mark.parametrize("batch_size", [1, 2, 4]) @pytest.mark.parametrize("batch_size", [10])
def test_parsable_pdf_image_extractor(image_extractor, pdf, images, metadata): def test_parsable_pdf_image_extractor(image_extractor, pdf, images, metadata):
images_extracted, metadata_extracted = map(list, zip(*image_extractor(images))) images_extracted, metadata_extracted = map(list, zip(*image_extractor(pdf)))
assert images_extracted == images # assert images_extracted == images
assert metadata_extracted == metadata assert list(metadata_extracted) == metadata