From bbafad556107581118e221be90bfde3efd36b3c5 Mon Sep 17 00:00:00 2001 From: Matthias Bisping Date: Tue, 12 Apr 2022 18:22:38 +0200 Subject: [PATCH] refactoring in preparationfor alpha channel info --- .../image_extractor/extractors/parsable.py | 51 +++++++++++++++++-- image_prediction/info.py | 1 + scripts/run_pipeline.py | 2 +- test/conftest.py | 12 +++-- test/unit_tests/image_extractor_test.py | 6 ++- 5 files changed, 61 insertions(+), 11 deletions(-) diff --git a/image_prediction/image_extractor/extractors/parsable.py b/image_prediction/image_extractor/extractors/parsable.py index 9cd04e5..ce24f40 100644 --- a/image_prediction/image_extractor/extractors/parsable.py +++ b/image_prediction/image_extractor/extractors/parsable.py @@ -1,11 +1,11 @@ import io from functools import partial, lru_cache -from itertools import chain, starmap, filterfalse +from itertools import chain, starmap, filterfalse, repeat from operator import itemgetter, truth import fitz from PIL import Image -from funcy import rcompose, compose, curry, merge +from funcy import rcompose, compose, curry, merge, zipdict from tqdm import tqdm from image_prediction.image_extractor.extractor import ImageExtractor, ImageMetadataPair @@ -44,6 +44,7 @@ class ParsablePDFImageExtractor(ImageExtractor): images = get_images_on_page(self.doc, page) metadata = get_metadata_for_images_on_page(page) get_image_infos.cache_clear() + load_image_handle_from_xref.cache_clear() image_metadata_pairs = starmap( ImageMetadataPair, filter(compose(all, curry(map)(truth)), zip(images, metadata)) @@ -63,7 +64,7 @@ def extract_pages(doc, page_range): def get_images_on_page(doc, page: fitz.Page): image_infos = get_image_infos(page) xrefs = map(itemgetter("xref"), image_infos) - images = map(partial(load_image_from_xref, doc), xrefs) + images = map(partial(xref_to_image, doc), xrefs) return images @@ -76,6 +77,11 @@ def get_metadata_for_images_on_page(page: fitz.Page): metadata = validate_size_and_passthrough(metadata) metadata = map(partial(merge, get_page_metadata(page)), metadata) + # xrefs = map(itemgetter("xref"), image_infos) + # alpha = map(has_alpha_channel, xrefs) + # alpha = zipdict(repeat(Info.ALPHA), alpha) + # metadata = starmap(merge, zip(alpha, metadata)) + yield from metadata @@ -87,8 +93,39 @@ def validate_size_and_passthrough(metadata): yield from map(validate_box_size, metadata) -def load_image_from_xref(doc, xref): - maybe_image = doc.extract_image(xref) +# def load_image_from_xref(doc, xref): +# +# maybe_image = doc.extract_image(xref) +# if maybe_image: +# smask = doc.extract_image(maybe_image["smask"]) +# pix1 = fitz.Pixmap(maybe_image) # (1) pixmap of image w/o alpha +# mask = fitz.Pixmap(doc.extract_image(smask)["image"]) # (2) mask pixmap +# pix = fitz.Pixmap(pix1, mask) # (3) copy of pix1, image mask added +# im = Image.open(io.BytesIO(pix.tobytes())) +# else: +# im = None +# +# import IPython +# IPython.embed() +# return im + + +@lru_cache(maxsize=None) +def load_image_handle_from_xref(doc, xref): + return doc.extract_image(xref) + + +def has_alpha_channel(doc, xref): + maybe_image = load_image_handle_from_xref(doc, xref) + if maybe_image: + maybe_smask = doc.extract_image(maybe_image["smask"]) + return maybe_smask is not None + else: + return False + + +def xref_to_image(doc, xref): + maybe_image = load_image_handle_from_xref(doc, xref) return Image.open(io.BytesIO(maybe_image["image"])) if maybe_image else None @@ -98,6 +135,10 @@ def get_image_infos(page: fitz.Page): def get_image_metadata(image_info): + + # import IPython + # IPython.embed() + # smask = doc.extract_image(maybe_image["smask"]) x1, y1, x2, y2 = map(rounder, image_info["bbox"]) width = abs(x2 - x1) diff --git a/image_prediction/info.py b/image_prediction/info.py index e9083ee..8483e09 100644 --- a/image_prediction/info.py +++ b/image_prediction/info.py @@ -11,3 +11,4 @@ class Info(Enum): X2 = "x2" Y1 = "y1" Y2 = "y2" + # ALPHA = "alpha" diff --git a/scripts/run_pipeline.py b/scripts/run_pipeline.py index c2b4bb0..d7bf253 100644 --- a/scripts/run_pipeline.py +++ b/scripts/run_pipeline.py @@ -35,7 +35,7 @@ def process_pdf(pipeline, pdf_path, page_range=None): def main(args): - pipeline = load_pipeline(verbose=True, tolerance=3) + pipeline = load_pipeline(verbose=False, tolerance=3) if os.path.isfile(args.input): pdf_paths = [args.input] diff --git a/test/conftest.py b/test/conftest.py index 1a18203..a0e6168 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -194,12 +194,16 @@ def input_size(request): def array_to_image(array): assert np.all(array <= 1) assert np.all(array >= 0) - return Image.fromarray(np.uint8(array * 255), mode="RGB") + if array.shape[-1] == 3: + mode = "RGB" + elif array.shape[-1] == 4: + mode = "RGBA" + else: + raise ValueError(f"Unexpected number of channels {array.shape[-1]}. Expected 3 or 4.") -@pytest.fixture(params=[{"width": 10, "height": 15, "depth": 3}, {"width": 150, "height": 100, "depth": 3}]) -def input_size(request): - return itemgetter("width", "height", "depth")(request.param) + # noinspection PyTypeChecker + return Image.fromarray(np.uint8(array * 255), mode=mode) @pytest.fixture diff --git a/test/unit_tests/image_extractor_test.py b/test/unit_tests/image_extractor_test.py index 3c72354..2f40d7d 100644 --- a/test/unit_tests/image_extractor_test.py +++ b/test/unit_tests/image_extractor_test.py @@ -17,7 +17,11 @@ def test_image_extractor_mock(image_extractor, images): @pytest.mark.parametrize("extractor_type", ["parsable_pdf", "default"]) -@pytest.mark.parametrize("input_size", [{"depth": 3, "width": 170, "height": 220}], indirect=["input_size"]) +@pytest.mark.parametrize( + "input_size", + [{"depth": 3, "width": 170, "height": 220}, {"depth": 3, "width": 170, "height": 220}], + indirect=["input_size"], +) def test_parsable_pdf_image_extractor(image_extractor, pdf, images, metadata, input_size): images_extracted, metadata_extracted = map(list, extract_images_from_pdf(pdf, image_extractor)) assert np.allclose(images_to_batch_tensor(images_extracted), images_to_batch_tensor(images))