alpha channel querying improved

This commit is contained in:
Matthias Bisping 2022-04-13 17:31:33 +02:00
parent 2cc52c4630
commit 7aee00cb49
3 changed files with 56 additions and 8 deletions

View File

@ -1,3 +1,4 @@
import atexit
import io import io
from functools import partial, lru_cache from functools import partial, lru_cache
from itertools import chain, starmap, filterfalse from itertools import chain, starmap, filterfalse
@ -90,6 +91,7 @@ def clear_caches():
get_image_infos.cache_clear() get_image_infos.cache_clear()
load_image_handle_from_xref.cache_clear() load_image_handle_from_xref.cache_clear()
get_images_on_page.cache_clear() get_images_on_page.cache_clear()
xref_to_image.cache_clear()
def validate_coords_and_passthrough(metadata): def validate_coords_and_passthrough(metadata):
@ -106,11 +108,18 @@ def load_image_handle_from_xref(doc, xref):
def has_alpha_channel(doc, xref): def has_alpha_channel(doc, xref):
maybe_image = load_image_handle_from_xref(doc, xref) maybe_image = load_image_handle_from_xref(doc, xref)
return doc.extract_image(maybe_image["smask"]) is not None if maybe_image else False maybe_smask = maybe_image["smask"] if maybe_image else None
if maybe_smask:
return any([doc.extract_image(maybe_smask) is not None, bool(fitz.Pixmap(doc, maybe_smask).alpha)])
else:
return bool(fitz.Pixmap(doc, xref).alpha)
def xref_to_image(doc, xref): @lru_cache(maxsize=None)
def xref_to_image(doc, xref) -> Image:
maybe_image = load_image_handle_from_xref(doc, xref) maybe_image = load_image_handle_from_xref(doc, xref)
return Image.open(io.BytesIO(maybe_image["image"])) if maybe_image else None return Image.open(io.BytesIO(maybe_image["image"])) if maybe_image else None
@ -152,3 +161,5 @@ def get_page_metadata(page):
rounder = rcompose(round, int) rounder = rcompose(round, int)
atexit.register(clear_caches)

View File

@ -337,11 +337,11 @@ def pdf(image_metadata_pairs):
return pdf_stream(pdf) return pdf_stream(pdf)
def add_image(pdf, image_metadata_pair): def add_image(pdf, image_metadata_pair, suffix="png"):
while fewer_pages_then_required(image_metadata_pair.metadata[Info.PAGE_IDX], pdf): while fewer_pages_then_required(image_metadata_pair.metadata[Info.PAGE_IDX], pdf):
pdf.add_page() pdf.add_page()
add_image_to_last_page(pdf, image_metadata_pair) add_image_to_last_page(pdf, image_metadata_pair, suffix=suffix)
def fewer_pages_then_required(page_idx, pdf): def fewer_pages_then_required(page_idx, pdf):
@ -352,13 +352,13 @@ def pdf_stream(pdf: fpdf.fpdf.FPDF):
return pdf.output(dest="S").encode("latin1") return pdf.output(dest="S").encode("latin1")
def add_image_to_last_page(pdf: fpdf.fpdf.FPDF, image_metadata_pair): def add_image_to_last_page(pdf: fpdf.fpdf.FPDF, image_metadata_pair, suffix):
image, metadata = image_metadata_pair image, metadata = image_metadata_pair
x, y, w, h = itemgetter(Info.X1, Info.Y1, Info.WIDTH, Info.HEIGHT)(metadata) x, y, w, h = itemgetter(Info.X1, Info.Y1, Info.WIDTH, Info.HEIGHT)(metadata)
with tempfile.NamedTemporaryFile(suffix=".png") as temp_image: with tempfile.NamedTemporaryFile(suffix=f".{suffix}") as temp_image:
image.save(temp_image.name) image.save(temp_image.name)
pdf.image(temp_image.name, x=x, y=y, w=w, h=h, type="png") pdf.image(temp_image.name, x=x, y=y, w=w, h=h, type=suffix)
@pytest.fixture @pytest.fixture

View File

@ -1,12 +1,19 @@
import random import random
from operator import itemgetter
import fitz import fitz
import fpdf
import numpy as np import numpy as np
import pytest import pytest
from PIL import Image
from funcy import first, rest
from image_prediction.estimator.preprocessor.utils import images_to_batch_tensor from image_prediction.estimator.preprocessor.utils import images_to_batch_tensor
from image_prediction.extraction import extract_images_from_pdf from image_prediction.extraction import extract_images_from_pdf
from image_prediction.image_extractor.extractors.parsable import extract_pages from image_prediction.image_extractor.extractor import ImageMetadataPair
from image_prediction.image_extractor.extractors.parsable import extract_pages, get_image_infos, has_alpha_channel
from image_prediction.info import Info
from test.conftest import add_image, pdf_stream
@pytest.mark.parametrize("extractor_type", ["mock"]) @pytest.mark.parametrize("extractor_type", ["mock"])
@ -39,3 +46,33 @@ def test_extract_pages(pdf):
pages = list(extract_pages(doc, page_range)) pages = list(extract_pages(doc, page_range))
assert all((isinstance(p, fitz.Page) for p in pages)) assert all((isinstance(p, fitz.Page) for p in pages))
assert len(pages) == len(page_range) assert len(pages) == len(page_range)
@pytest.mark.parametrize("suffix", ["gif", "png", "jpeg"])
@pytest.mark.parametrize("mode", ["RGB", "RGBA"])
def test_has_alpha_channel(base_patch_metadata, suffix, mode):
mode = "RGB" if suffix == "jpeg" else mode
pdf = fpdf.FPDF(unit="pt")
image = Image.new(mode, itemgetter(Info.WIDTH, Info.HEIGHT)(base_patch_metadata), color=(10, 10, 10))
add_image(pdf, ImageMetadataPair(image, base_patch_metadata), suffix=suffix)
doc = fitz.Document(stream=pdf_stream(pdf))
page = first(doc)
xrefs = map(itemgetter("xref"), get_image_infos(page))
result = has_alpha_channel(doc, first(xrefs))
if mode == "RGBA":
assert result
if mode == "RGB":
assert not result
assert not list(rest(xrefs))
doc.close()