From b818ee47245ce3d99ded12dda01d77ed73c53fe4 Mon Sep 17 00:00:00 2001 From: Matthias Bisping Date: Mon, 28 Mar 2022 16:38:46 +0200 Subject: [PATCH] fixed misaligned metadata and images --- image_prediction/extraction.py | 13 +++++++++++++ .../image_extractor/extractors/parsable.py | 9 +++++---- test/unit_tests/image_extractor_test.py | 12 +++++++----- 3 files changed, 25 insertions(+), 9 deletions(-) create mode 100644 image_prediction/extraction.py diff --git a/image_prediction/extraction.py b/image_prediction/extraction.py new file mode 100644 index 0000000..b996ed7 --- /dev/null +++ b/image_prediction/extraction.py @@ -0,0 +1,13 @@ +from image_prediction.image_extractor.extractors.parsable import ParsablePDFImageExtractor + + +def extract_images_from_pdf(pdf, extractor=None): + + if not extractor: + extractor = ParsablePDFImageExtractor() + + try: + images_extracted, metadata_extracted = zip(*extractor(pdf)) + return images_extracted, metadata_extracted + except ValueError: + return [], [] diff --git a/image_prediction/image_extractor/extractors/parsable.py b/image_prediction/image_extractor/extractors/parsable.py index 0853526..4ebbb6c 100644 --- a/image_prediction/image_extractor/extractors/parsable.py +++ b/image_prediction/image_extractor/extractors/parsable.py @@ -1,7 +1,9 @@ +import io from itertools import chain, starmap from operator import itemgetter import fitz +from PIL import Image from funcy import rcompose from image_prediction.image_extractor.extractor import ImageExtractor, ImageMetadataPair @@ -13,7 +15,7 @@ class ParsablePDFImageExtractor(ImageExtractor): def __process_images_on_page(self, page: fitz.fitz.Page): def load_image_from_xref(xref): - return self.doc.extract_image(xref)["image"] + return Image.open(io.BytesIO(self.doc.extract_image(xref)["image"])) def format_metadata(image_info): x1, y1, x2, y2 = map(rounder, image_info["bbox"]) @@ -34,10 +36,9 @@ class ParsablePDFImageExtractor(ImageExtractor): page_width, page_height = map(rounder, page.mediabox_size) - image_handles = page.get_images(full=True) - xrefs = map(itemgetter(0), image_handles) + image_infos = page.get_image_info(xrefs=True) + xrefs = map(itemgetter("xref"), image_infos) images = map(load_image_from_xref, xrefs) - image_infos = page.get_image_info() metadata = map(format_metadata, image_infos) return starmap(ImageMetadataPair, zip(images, metadata)) diff --git a/test/unit_tests/image_extractor_test.py b/test/unit_tests/image_extractor_test.py index 9166da9..9ad2c9f 100644 --- a/test/unit_tests/image_extractor_test.py +++ b/test/unit_tests/image_extractor_test.py @@ -1,7 +1,9 @@ -import time - +import numpy as np import pytest +from image_prediction.estimator.preprocessor.utils import images_to_batch_tensor +from image_prediction.extraction import extract_images_from_pdf + @pytest.mark.parametrize("extractor_type", ["mock"]) @pytest.mark.parametrize("batch_size", [1, 2, 4]) @@ -11,8 +13,8 @@ def test_image_extractor_mock(image_extractor, images): @pytest.mark.parametrize("extractor_type", ["parsable_pdf"]) -@pytest.mark.parametrize("batch_size", [10]) +@pytest.mark.parametrize("batch_size", [0, 1, 2, 4, 8]) def test_parsable_pdf_image_extractor(image_extractor, pdf, images, metadata): - images_extracted, metadata_extracted = map(list, zip(*image_extractor(pdf))) - # assert images_extracted == images + images_extracted, metadata_extracted = map(list, extract_images_from_pdf(pdf, image_extractor)) + assert np.allclose(images_to_batch_tensor(images_extracted), images_to_batch_tensor(images)) assert list(metadata_extracted) == metadata