From 62bfedfea81238034c8ffa1f522a21cf94a00743 Mon Sep 17 00:00:00 2001 From: Matthias Bisping Date: Wed, 13 Apr 2022 12:06:55 +0200 Subject: [PATCH] alpha channel test fix --- .../image_extractor/extractors/parsable.py | 50 +++++++------------ test/conftest.py | 17 ++++--- test/unit_tests/image_extractor_test.py | 12 ++--- 3 files changed, 31 insertions(+), 48 deletions(-) diff --git a/image_prediction/image_extractor/extractors/parsable.py b/image_prediction/image_extractor/extractors/parsable.py index 77bc024..9f684a3 100644 --- a/image_prediction/image_extractor/extractors/parsable.py +++ b/image_prediction/image_extractor/extractors/parsable.py @@ -1,11 +1,12 @@ import io from functools import partial, lru_cache -from itertools import chain, starmap, filterfalse, repeat -from operator import itemgetter, truth +from itertools import chain, starmap, filterfalse +from operator import itemgetter +from typing import List import fitz from PIL import Image -from funcy import rcompose, compose, curry, merge, zipdict +from funcy import rcompose, merge from tqdm import tqdm from image_prediction.image_extractor.extractor import ImageExtractor, ImageMetadataPair @@ -43,12 +44,9 @@ class ParsablePDFImageExtractor(ImageExtractor): def __process_images_on_page(self, page: fitz.fitz.Page): images = get_images_on_page(self.doc, page) metadata = get_metadata_for_images_on_page(self.doc, page) - get_image_infos.cache_clear() - load_image_handle_from_xref.cache_clear() + clear_caches() - image_metadata_pairs = starmap( - ImageMetadataPair, filter(compose(all, curry(map)(truth)), zip(images, metadata)) - ) + image_metadata_pairs = starmap(ImageMetadataPair, filter(all, zip(images, metadata))) image_metadata_pairs = stitch_pairs(list(image_metadata_pairs), tolerance=self.tolerance) yield from image_metadata_pairs @@ -61,6 +59,7 @@ def extract_pages(doc, page_range): return pages +@lru_cache(maxsize=None) def get_images_on_page(doc, page: fitz.Page): image_infos = get_image_infos(page) xrefs = map(itemgetter("xref"), image_infos) @@ -70,9 +69,8 @@ def get_images_on_page(doc, page: fitz.Page): def get_metadata_for_images_on_page(doc, page: fitz.Page): - image_infos = get_image_infos(page) - metadata = map(get_image_metadata, image_infos) + metadata = map(get_image_metadata, get_image_infos(page)) metadata = validate_coords_and_passthrough(metadata) metadata = filterfalse(tiny, metadata) @@ -80,14 +78,20 @@ def get_metadata_for_images_on_page(doc, page: fitz.Page): metadata = map(partial(merge, get_page_metadata(page)), metadata) - xrefs = map(itemgetter("xref"), image_infos) + xrefs = map(itemgetter("xref"), get_image_infos(page)) alpha = map(partial(has_alpha_channel, doc), xrefs) alpha = ({Info.ALPHA: a} for a in alpha) - metadata = starmap(merge, zip(alpha, metadata)) + metadata = list(starmap(merge, zip(alpha, metadata))) yield from metadata +def clear_caches(): + get_image_infos.cache_clear() + load_image_handle_from_xref.cache_clear() + get_images_on_page.cache_clear() + + def validate_coords_and_passthrough(metadata): yield from map(validate_box_coords, metadata) @@ -96,23 +100,6 @@ def validate_size_and_passthrough(metadata): yield from map(validate_box_size, metadata) -# def load_image_from_xref(doc, xref): -# -# maybe_image = doc.extract_image(xref) -# if maybe_image: -# smask = doc.extract_image(maybe_image["smask"]) -# pix1 = fitz.Pixmap(maybe_image) # (1) pixmap of image w/o alpha -# mask = fitz.Pixmap(doc.extract_image(smask)["image"]) # (2) mask pixmap -# pix = fitz.Pixmap(pix1, mask) # (3) copy of pix1, image mask added -# im = Image.open(io.BytesIO(pix.tobytes())) -# else: -# im = None -# -# import IPython -# IPython.embed() -# return im - - @lru_cache(maxsize=None) def load_image_handle_from_xref(doc, xref): return doc.extract_image(xref) @@ -133,15 +120,12 @@ def xref_to_image(doc, xref): @lru_cache(maxsize=None) -def get_image_infos(page: fitz.Page): +def get_image_infos(page: fitz.Page) -> List[dict]: return page.get_image_info(xrefs=True) def get_image_metadata(image_info): - # import IPython - # IPython.embed() - # smask = doc.extract_image(maybe_image["smask"]) x1, y1, x2, y2 = map(rounder, image_info["bbox"]) width = abs(x2 - x1) diff --git a/test/conftest.py b/test/conftest.py index 42559e4..a78c366 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -186,14 +186,15 @@ def batch_size(request): return request.param +@pytest.fixture +def input_size(alpha, __input_size): + w, h, d = __input_size + return w, h, d + alpha + + @pytest.fixture(params=[False]) -def input_size(request, __input_size): - alpha = request.param - print(alpha) - if alpha: - w, h, d = __input_size - __input_size = w, h, d + 1 - return __input_size +def alpha(request): + return request.param @pytest.fixture(params=[{"width": 10, "height": 15, "depth": 3}, {"width": 150, "height": 100, "depth": 3}]) @@ -301,7 +302,7 @@ def metadata(images, info_label_map): info_label_map.X2: x2, info_label_map.Y1: y1, info_label_map.Y2: y2, - info_label_map.ALPHA: image.mode == "RGBA" + info_label_map.ALPHA: image.mode == "RGBA", } return metadata diff --git a/test/unit_tests/image_extractor_test.py b/test/unit_tests/image_extractor_test.py index 2f40d7d..4548fe1 100644 --- a/test/unit_tests/image_extractor_test.py +++ b/test/unit_tests/image_extractor_test.py @@ -17,14 +17,12 @@ def test_image_extractor_mock(image_extractor, images): @pytest.mark.parametrize("extractor_type", ["parsable_pdf", "default"]) -@pytest.mark.parametrize( - "input_size", - [{"depth": 3, "width": 170, "height": 220}, {"depth": 3, "width": 170, "height": 220}], - indirect=["input_size"], -) -def test_parsable_pdf_image_extractor(image_extractor, pdf, images, metadata, input_size): +@pytest.mark.parametrize("input_size", [{"depth": 3, "width": 170, "height": 220}], indirect=["input_size"]) +@pytest.mark.parametrize("alpha", [False, True]) +def test_parsable_pdf_image_extractor(image_extractor, pdf, images, metadata, input_size, alpha): images_extracted, metadata_extracted = map(list, extract_images_from_pdf(pdf, image_extractor)) - assert np.allclose(images_to_batch_tensor(images_extracted), images_to_batch_tensor(images)) + if not alpha: + assert np.allclose(images_to_batch_tensor(images_extracted), images_to_batch_tensor(images)) assert list(metadata_extracted) == metadata