Refactoring

This commit is contained in:
Matthias Bisping 2023-02-06 13:22:33 +01:00
parent 112e18ebb5
commit 0cf8e047c5
6 changed files with 56 additions and 9 deletions

View File

@ -1,7 +1,21 @@
import numpy as np
import pytest
from dvc.repo import Repo
from image_prediction.locations import PACKAGE_ROOT_DIR, TEST_DATA_DIR_DVC
from image_prediction.utils import get_logger
logger = get_logger()
@pytest.fixture
def input_batch(batch_size, input_size):
return np.random.random_sample(size=(batch_size, *input_size))
@pytest.fixture(scope="session")
def dvc_test_data():
logger.info("Pulling data with DVC...")
# noinspection PyCallingNonCallable
Repo(PACKAGE_ROOT_DIR).pull(targets=[str(TEST_DATA_DIR_DVC)])
logger.info("Finished pulling data.")

12
test/fixtures/pdf.py vendored
View File

@ -4,7 +4,7 @@ import fpdf
import pytest
from image_prediction.locations import TEST_DATA_DIR
from test.utils.generation.pdf import add_image, pdf_stream
from test.utils.generation.pdf import add_image, pdf_stream, stream_pdf_bytes
@pytest.fixture
@ -18,6 +18,10 @@ def pdf(image_metadata_pairs):
@pytest.fixture
def real_pdf():
with open(os.path.join(TEST_DATA_DIR, "f2dc689ca794fccb8cd38b95f2bf6ba9.pdf"), "rb") as f:
yield f.read()
def real_pdf(dvc_test_data):
yield from stream_pdf_bytes(TEST_DATA_DIR / "f2dc689ca794fccb8cd38b95f2bf6ba9.pdf")
@pytest.fixture
def bad_xref_pdf(dvc_test_data):
yield from stream_pdf_bytes(TEST_DATA_DIR / "bad_xref.pdf")

View File

@ -87,7 +87,7 @@ def expected_predictions_mapped_and_formatted(expected_predictions_mapped):
@pytest.fixture
def real_expected_service_response():
def real_expected_service_response(dvc_test_data):
with open(os.path.join(TEST_DATA_DIR, "f2dc689ca794fccb8cd38b95f2bf6ba9_predictions.json"), "r") as f:
yield json.load(f)

View File

@ -1,3 +1,4 @@
import json
import random
from operator import itemgetter
@ -7,12 +8,23 @@ import pytest
from PIL import Image
from funcy import first, rest
from image_prediction.exceptions import BadXref
from image_prediction.extraction import extract_images_from_pdf
from image_prediction.formatter.formatters.enum import EnumFormatter
from image_prediction.image_extractor.extractor import ImageMetadataPair
from image_prediction.image_extractor.extractors.parsable import extract_pages, has_alpha_channel, get_image_infos
from image_prediction.image_extractor.extractors.parsable import (
extract_pages,
has_alpha_channel,
get_image_infos,
ParsablePDFImageExtractor,
extract_valid_metadata,
xref_to_maybe_image,
extract_image,
)
from image_prediction.info import Info
from image_prediction.locations import TEST_DATA_DIR
from test.utils.comparison import metadata_equal, image_sets_equal
from test.utils.generation.pdf import add_image, pdf_stream
from test.utils.generation.pdf import add_image, pdf_stream, stream_pdf_bytes
@pytest.mark.parametrize("extractor_type", ["mock"])
@ -75,3 +87,15 @@ def test_has_alpha_channel(base_patch_metadata, suffix, mode):
assert not list(rest(xrefs))
doc.close()
def test_bad_xref_handling(bad_xref_pdf, dvc_test_data):
doc = fitz.Document(stream=bad_xref_pdf)
metadata = extract_valid_metadata(doc, first(doc))
xref = first(metadata)[Info.XREF]
with pytest.raises(BadXref):
extract_image(doc, xref)
assert xref_to_maybe_image(doc, xref) is None

View File

@ -60,10 +60,10 @@ def test_image_stitcher(patch_image_metadata_pairs, base_patch_metadata, base_pa
assert images_equal(pair_stitched.image.resize((10, 10)), base_patch_image.resize((10, 10)), atol=0.4)
def test_image_stitcher_with_gaps_must_succeed():
def test_image_stitcher_with_gaps_must_succeed(dvc_test_data):
from image_prediction.locations import TEST_DATA_DIR
with open(os.path.join(TEST_DATA_DIR, "stitching_with_tolerance.json")) as f:
with open(TEST_DATA_DIR / "stitching_with_tolerance.json") as f:
patches_metadata, base_patch_metadata = itemgetter("input", "target")(ReverseEnumFormatter(Info)(json.load(f)))
images = map(gray_image_from_metadata, patches_metadata)

View File

@ -28,3 +28,8 @@ def add_image_to_last_page(pdf: fpdf.fpdf.FPDF, image_metadata_pair, suffix):
with tempfile.NamedTemporaryFile(suffix=f".{suffix}") as temp_image:
image.save(temp_image.name)
pdf.image(temp_image.name, x=x, y=y, w=w, h=h, type=suffix)
def stream_pdf_bytes(path: str):
with open(path, "rb") as f:
yield f.read()