From 7aee00cb49d8d3d30c8241b1839af446a42b756e Mon Sep 17 00:00:00 2001 From: Matthias Bisping Date: Wed, 13 Apr 2022 17:31:33 +0200 Subject: [PATCH] alpha channel querying improved --- .../image_extractor/extractors/parsable.py | 15 ++++++- test/conftest.py | 10 ++--- test/unit_tests/image_extractor_test.py | 39 ++++++++++++++++++- 3 files changed, 56 insertions(+), 8 deletions(-) diff --git a/image_prediction/image_extractor/extractors/parsable.py b/image_prediction/image_extractor/extractors/parsable.py index 1a8e1ef..9de0884 100644 --- a/image_prediction/image_extractor/extractors/parsable.py +++ b/image_prediction/image_extractor/extractors/parsable.py @@ -1,3 +1,4 @@ +import atexit import io from functools import partial, lru_cache from itertools import chain, starmap, filterfalse @@ -90,6 +91,7 @@ def clear_caches(): get_image_infos.cache_clear() load_image_handle_from_xref.cache_clear() get_images_on_page.cache_clear() + xref_to_image.cache_clear() def validate_coords_and_passthrough(metadata): @@ -106,11 +108,18 @@ def load_image_handle_from_xref(doc, xref): def has_alpha_channel(doc, xref): + maybe_image = load_image_handle_from_xref(doc, xref) - return doc.extract_image(maybe_image["smask"]) is not None if maybe_image else False + maybe_smask = maybe_image["smask"] if maybe_image else None + + if maybe_smask: + return any([doc.extract_image(maybe_smask) is not None, bool(fitz.Pixmap(doc, maybe_smask).alpha)]) + else: + return bool(fitz.Pixmap(doc, xref).alpha) -def xref_to_image(doc, xref): +@lru_cache(maxsize=None) +def xref_to_image(doc, xref) -> Image: maybe_image = load_image_handle_from_xref(doc, xref) return Image.open(io.BytesIO(maybe_image["image"])) if maybe_image else None @@ -152,3 +161,5 @@ def get_page_metadata(page): rounder = rcompose(round, int) + +atexit.register(clear_caches) diff --git a/test/conftest.py b/test/conftest.py index 5400ff5..a919c38 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -337,11 +337,11 @@ def pdf(image_metadata_pairs): return pdf_stream(pdf) -def add_image(pdf, image_metadata_pair): +def add_image(pdf, image_metadata_pair, suffix="png"): while fewer_pages_then_required(image_metadata_pair.metadata[Info.PAGE_IDX], pdf): pdf.add_page() - add_image_to_last_page(pdf, image_metadata_pair) + add_image_to_last_page(pdf, image_metadata_pair, suffix=suffix) def fewer_pages_then_required(page_idx, pdf): @@ -352,13 +352,13 @@ def pdf_stream(pdf: fpdf.fpdf.FPDF): return pdf.output(dest="S").encode("latin1") -def add_image_to_last_page(pdf: fpdf.fpdf.FPDF, image_metadata_pair): +def add_image_to_last_page(pdf: fpdf.fpdf.FPDF, image_metadata_pair, suffix): image, metadata = image_metadata_pair x, y, w, h = itemgetter(Info.X1, Info.Y1, Info.WIDTH, Info.HEIGHT)(metadata) - with tempfile.NamedTemporaryFile(suffix=".png") as temp_image: + with tempfile.NamedTemporaryFile(suffix=f".{suffix}") as temp_image: image.save(temp_image.name) - pdf.image(temp_image.name, x=x, y=y, w=w, h=h, type="png") + pdf.image(temp_image.name, x=x, y=y, w=w, h=h, type=suffix) @pytest.fixture diff --git a/test/unit_tests/image_extractor_test.py b/test/unit_tests/image_extractor_test.py index 4548fe1..a6c6078 100644 --- a/test/unit_tests/image_extractor_test.py +++ b/test/unit_tests/image_extractor_test.py @@ -1,12 +1,19 @@ import random +from operator import itemgetter import fitz +import fpdf import numpy as np import pytest +from PIL import Image +from funcy import first, rest from image_prediction.estimator.preprocessor.utils import images_to_batch_tensor from image_prediction.extraction import extract_images_from_pdf -from image_prediction.image_extractor.extractors.parsable import extract_pages +from image_prediction.image_extractor.extractor import ImageMetadataPair +from image_prediction.image_extractor.extractors.parsable import extract_pages, get_image_infos, has_alpha_channel +from image_prediction.info import Info +from test.conftest import add_image, pdf_stream @pytest.mark.parametrize("extractor_type", ["mock"]) @@ -39,3 +46,33 @@ def test_extract_pages(pdf): pages = list(extract_pages(doc, page_range)) assert all((isinstance(p, fitz.Page) for p in pages)) assert len(pages) == len(page_range) + + +@pytest.mark.parametrize("suffix", ["gif", "png", "jpeg"]) +@pytest.mark.parametrize("mode", ["RGB", "RGBA"]) +def test_has_alpha_channel(base_patch_metadata, suffix, mode): + + mode = "RGB" if suffix == "jpeg" else mode + + pdf = fpdf.FPDF(unit="pt") + + image = Image.new(mode, itemgetter(Info.WIDTH, Info.HEIGHT)(base_patch_metadata), color=(10, 10, 10)) + + add_image(pdf, ImageMetadataPair(image, base_patch_metadata), suffix=suffix) + + doc = fitz.Document(stream=pdf_stream(pdf)) + + page = first(doc) + + xrefs = map(itemgetter("xref"), get_image_infos(page)) + + result = has_alpha_channel(doc, first(xrefs)) + + if mode == "RGBA": + assert result + if mode == "RGB": + assert not result + + assert not list(rest(xrefs)) + + doc.close()