diff --git a/test/fixtures/input.py b/test/fixtures/input.py index b02f414..2054df6 100644 --- a/test/fixtures/input.py +++ b/test/fixtures/input.py @@ -1,7 +1,21 @@ import numpy as np import pytest +from dvc.repo import Repo + +from image_prediction.locations import PACKAGE_ROOT_DIR, TEST_DATA_DIR_DVC +from image_prediction.utils import get_logger + +logger = get_logger() @pytest.fixture def input_batch(batch_size, input_size): return np.random.random_sample(size=(batch_size, *input_size)) + + +@pytest.fixture(scope="session") +def dvc_test_data(): + logger.info("Pulling data with DVC...") + # noinspection PyCallingNonCallable + Repo(PACKAGE_ROOT_DIR).pull(targets=[str(TEST_DATA_DIR_DVC)]) + logger.info("Finished pulling data.") diff --git a/test/fixtures/pdf.py b/test/fixtures/pdf.py index 7353917..0991bbe 100644 --- a/test/fixtures/pdf.py +++ b/test/fixtures/pdf.py @@ -4,7 +4,7 @@ import fpdf import pytest from image_prediction.locations import TEST_DATA_DIR -from test.utils.generation.pdf import add_image, pdf_stream +from test.utils.generation.pdf import add_image, pdf_stream, stream_pdf_bytes @pytest.fixture @@ -18,6 +18,10 @@ def pdf(image_metadata_pairs): @pytest.fixture -def real_pdf(): - with open(os.path.join(TEST_DATA_DIR, "f2dc689ca794fccb8cd38b95f2bf6ba9.pdf"), "rb") as f: - yield f.read() +def real_pdf(dvc_test_data): + yield from stream_pdf_bytes(TEST_DATA_DIR / "f2dc689ca794fccb8cd38b95f2bf6ba9.pdf") + + +@pytest.fixture +def bad_xref_pdf(dvc_test_data): + yield from stream_pdf_bytes(TEST_DATA_DIR / "bad_xref.pdf") diff --git a/test/fixtures/target.py b/test/fixtures/target.py index 23f23bd..1f111fc 100644 --- a/test/fixtures/target.py +++ b/test/fixtures/target.py @@ -87,7 +87,7 @@ def expected_predictions_mapped_and_formatted(expected_predictions_mapped): @pytest.fixture -def real_expected_service_response(): +def real_expected_service_response(dvc_test_data): with open(os.path.join(TEST_DATA_DIR, "f2dc689ca794fccb8cd38b95f2bf6ba9_predictions.json"), "r") as f: yield json.load(f) diff --git a/test/unit_tests/image_extractor_test.py b/test/unit_tests/image_extractor_test.py index 8e6916c..fa006e9 100644 --- a/test/unit_tests/image_extractor_test.py +++ b/test/unit_tests/image_extractor_test.py @@ -1,3 +1,4 @@ +import json import random from operator import itemgetter @@ -7,12 +8,23 @@ import pytest from PIL import Image from funcy import first, rest +from image_prediction.exceptions import BadXref from image_prediction.extraction import extract_images_from_pdf +from image_prediction.formatter.formatters.enum import EnumFormatter from image_prediction.image_extractor.extractor import ImageMetadataPair -from image_prediction.image_extractor.extractors.parsable import extract_pages, has_alpha_channel, get_image_infos +from image_prediction.image_extractor.extractors.parsable import ( + extract_pages, + has_alpha_channel, + get_image_infos, + ParsablePDFImageExtractor, + extract_valid_metadata, + xref_to_maybe_image, + extract_image, +) from image_prediction.info import Info +from image_prediction.locations import TEST_DATA_DIR from test.utils.comparison import metadata_equal, image_sets_equal -from test.utils.generation.pdf import add_image, pdf_stream +from test.utils.generation.pdf import add_image, pdf_stream, stream_pdf_bytes @pytest.mark.parametrize("extractor_type", ["mock"]) @@ -75,3 +87,15 @@ def test_has_alpha_channel(base_patch_metadata, suffix, mode): assert not list(rest(xrefs)) doc.close() + + +def test_bad_xref_handling(bad_xref_pdf, dvc_test_data): + + doc = fitz.Document(stream=bad_xref_pdf) + metadata = extract_valid_metadata(doc, first(doc)) + xref = first(metadata)[Info.XREF] + + with pytest.raises(BadXref): + extract_image(doc, xref) + + assert xref_to_maybe_image(doc, xref) is None diff --git a/test/unit_tests/image_stitching_test.py b/test/unit_tests/image_stitching_test.py index edf7923..3762036 100644 --- a/test/unit_tests/image_stitching_test.py +++ b/test/unit_tests/image_stitching_test.py @@ -60,10 +60,10 @@ def test_image_stitcher(patch_image_metadata_pairs, base_patch_metadata, base_pa assert images_equal(pair_stitched.image.resize((10, 10)), base_patch_image.resize((10, 10)), atol=0.4) -def test_image_stitcher_with_gaps_must_succeed(): +def test_image_stitcher_with_gaps_must_succeed(dvc_test_data): from image_prediction.locations import TEST_DATA_DIR - with open(os.path.join(TEST_DATA_DIR, "stitching_with_tolerance.json")) as f: + with open(TEST_DATA_DIR / "stitching_with_tolerance.json") as f: patches_metadata, base_patch_metadata = itemgetter("input", "target")(ReverseEnumFormatter(Info)(json.load(f))) images = map(gray_image_from_metadata, patches_metadata) diff --git a/test/utils/generation/pdf.py b/test/utils/generation/pdf.py index 852647e..111a6d4 100644 --- a/test/utils/generation/pdf.py +++ b/test/utils/generation/pdf.py @@ -28,3 +28,8 @@ def add_image_to_last_page(pdf: fpdf.fpdf.FPDF, image_metadata_pair, suffix): with tempfile.NamedTemporaryFile(suffix=f".{suffix}") as temp_image: image.save(temp_image.name) pdf.image(temp_image.name, x=x, y=y, w=w, h=h, type=suffix) + + +def stream_pdf_bytes(path: str): + with open(path, "rb") as f: + yield f.read()