alpha channel querying improved
This commit is contained in:
parent
2cc52c4630
commit
7aee00cb49
@ -1,3 +1,4 @@
|
|||||||
|
import atexit
|
||||||
import io
|
import io
|
||||||
from functools import partial, lru_cache
|
from functools import partial, lru_cache
|
||||||
from itertools import chain, starmap, filterfalse
|
from itertools import chain, starmap, filterfalse
|
||||||
@ -90,6 +91,7 @@ def clear_caches():
|
|||||||
get_image_infos.cache_clear()
|
get_image_infos.cache_clear()
|
||||||
load_image_handle_from_xref.cache_clear()
|
load_image_handle_from_xref.cache_clear()
|
||||||
get_images_on_page.cache_clear()
|
get_images_on_page.cache_clear()
|
||||||
|
xref_to_image.cache_clear()
|
||||||
|
|
||||||
|
|
||||||
def validate_coords_and_passthrough(metadata):
|
def validate_coords_and_passthrough(metadata):
|
||||||
@ -106,11 +108,18 @@ def load_image_handle_from_xref(doc, xref):
|
|||||||
|
|
||||||
|
|
||||||
def has_alpha_channel(doc, xref):
|
def has_alpha_channel(doc, xref):
|
||||||
|
|
||||||
maybe_image = load_image_handle_from_xref(doc, xref)
|
maybe_image = load_image_handle_from_xref(doc, xref)
|
||||||
return doc.extract_image(maybe_image["smask"]) is not None if maybe_image else False
|
maybe_smask = maybe_image["smask"] if maybe_image else None
|
||||||
|
|
||||||
|
if maybe_smask:
|
||||||
|
return any([doc.extract_image(maybe_smask) is not None, bool(fitz.Pixmap(doc, maybe_smask).alpha)])
|
||||||
|
else:
|
||||||
|
return bool(fitz.Pixmap(doc, xref).alpha)
|
||||||
|
|
||||||
|
|
||||||
def xref_to_image(doc, xref):
|
@lru_cache(maxsize=None)
|
||||||
|
def xref_to_image(doc, xref) -> Image:
|
||||||
maybe_image = load_image_handle_from_xref(doc, xref)
|
maybe_image = load_image_handle_from_xref(doc, xref)
|
||||||
return Image.open(io.BytesIO(maybe_image["image"])) if maybe_image else None
|
return Image.open(io.BytesIO(maybe_image["image"])) if maybe_image else None
|
||||||
|
|
||||||
@ -152,3 +161,5 @@ def get_page_metadata(page):
|
|||||||
|
|
||||||
|
|
||||||
rounder = rcompose(round, int)
|
rounder = rcompose(round, int)
|
||||||
|
|
||||||
|
atexit.register(clear_caches)
|
||||||
|
|||||||
@ -337,11 +337,11 @@ def pdf(image_metadata_pairs):
|
|||||||
return pdf_stream(pdf)
|
return pdf_stream(pdf)
|
||||||
|
|
||||||
|
|
||||||
def add_image(pdf, image_metadata_pair):
|
def add_image(pdf, image_metadata_pair, suffix="png"):
|
||||||
while fewer_pages_then_required(image_metadata_pair.metadata[Info.PAGE_IDX], pdf):
|
while fewer_pages_then_required(image_metadata_pair.metadata[Info.PAGE_IDX], pdf):
|
||||||
pdf.add_page()
|
pdf.add_page()
|
||||||
|
|
||||||
add_image_to_last_page(pdf, image_metadata_pair)
|
add_image_to_last_page(pdf, image_metadata_pair, suffix=suffix)
|
||||||
|
|
||||||
|
|
||||||
def fewer_pages_then_required(page_idx, pdf):
|
def fewer_pages_then_required(page_idx, pdf):
|
||||||
@ -352,13 +352,13 @@ def pdf_stream(pdf: fpdf.fpdf.FPDF):
|
|||||||
return pdf.output(dest="S").encode("latin1")
|
return pdf.output(dest="S").encode("latin1")
|
||||||
|
|
||||||
|
|
||||||
def add_image_to_last_page(pdf: fpdf.fpdf.FPDF, image_metadata_pair):
|
def add_image_to_last_page(pdf: fpdf.fpdf.FPDF, image_metadata_pair, suffix):
|
||||||
image, metadata = image_metadata_pair
|
image, metadata = image_metadata_pair
|
||||||
x, y, w, h = itemgetter(Info.X1, Info.Y1, Info.WIDTH, Info.HEIGHT)(metadata)
|
x, y, w, h = itemgetter(Info.X1, Info.Y1, Info.WIDTH, Info.HEIGHT)(metadata)
|
||||||
|
|
||||||
with tempfile.NamedTemporaryFile(suffix=".png") as temp_image:
|
with tempfile.NamedTemporaryFile(suffix=f".{suffix}") as temp_image:
|
||||||
image.save(temp_image.name)
|
image.save(temp_image.name)
|
||||||
pdf.image(temp_image.name, x=x, y=y, w=w, h=h, type="png")
|
pdf.image(temp_image.name, x=x, y=y, w=w, h=h, type=suffix)
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
|
|||||||
@ -1,12 +1,19 @@
|
|||||||
import random
|
import random
|
||||||
|
from operator import itemgetter
|
||||||
|
|
||||||
import fitz
|
import fitz
|
||||||
|
import fpdf
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pytest
|
import pytest
|
||||||
|
from PIL import Image
|
||||||
|
from funcy import first, rest
|
||||||
|
|
||||||
from image_prediction.estimator.preprocessor.utils import images_to_batch_tensor
|
from image_prediction.estimator.preprocessor.utils import images_to_batch_tensor
|
||||||
from image_prediction.extraction import extract_images_from_pdf
|
from image_prediction.extraction import extract_images_from_pdf
|
||||||
from image_prediction.image_extractor.extractors.parsable import extract_pages
|
from image_prediction.image_extractor.extractor import ImageMetadataPair
|
||||||
|
from image_prediction.image_extractor.extractors.parsable import extract_pages, get_image_infos, has_alpha_channel
|
||||||
|
from image_prediction.info import Info
|
||||||
|
from test.conftest import add_image, pdf_stream
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("extractor_type", ["mock"])
|
@pytest.mark.parametrize("extractor_type", ["mock"])
|
||||||
@ -39,3 +46,33 @@ def test_extract_pages(pdf):
|
|||||||
pages = list(extract_pages(doc, page_range))
|
pages = list(extract_pages(doc, page_range))
|
||||||
assert all((isinstance(p, fitz.Page) for p in pages))
|
assert all((isinstance(p, fitz.Page) for p in pages))
|
||||||
assert len(pages) == len(page_range)
|
assert len(pages) == len(page_range)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("suffix", ["gif", "png", "jpeg"])
|
||||||
|
@pytest.mark.parametrize("mode", ["RGB", "RGBA"])
|
||||||
|
def test_has_alpha_channel(base_patch_metadata, suffix, mode):
|
||||||
|
|
||||||
|
mode = "RGB" if suffix == "jpeg" else mode
|
||||||
|
|
||||||
|
pdf = fpdf.FPDF(unit="pt")
|
||||||
|
|
||||||
|
image = Image.new(mode, itemgetter(Info.WIDTH, Info.HEIGHT)(base_patch_metadata), color=(10, 10, 10))
|
||||||
|
|
||||||
|
add_image(pdf, ImageMetadataPair(image, base_patch_metadata), suffix=suffix)
|
||||||
|
|
||||||
|
doc = fitz.Document(stream=pdf_stream(pdf))
|
||||||
|
|
||||||
|
page = first(doc)
|
||||||
|
|
||||||
|
xrefs = map(itemgetter("xref"), get_image_infos(page))
|
||||||
|
|
||||||
|
result = has_alpha_channel(doc, first(xrefs))
|
||||||
|
|
||||||
|
if mode == "RGBA":
|
||||||
|
assert result
|
||||||
|
if mode == "RGB":
|
||||||
|
assert not result
|
||||||
|
|
||||||
|
assert not list(rest(xrefs))
|
||||||
|
|
||||||
|
doc.close()
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user