Refactoring
This commit is contained in:
parent
112e18ebb5
commit
0cf8e047c5
14
test/fixtures/input.py
vendored
14
test/fixtures/input.py
vendored
@ -1,7 +1,21 @@
|
||||
import numpy as np
|
||||
import pytest
|
||||
from dvc.repo import Repo
|
||||
|
||||
from image_prediction.locations import PACKAGE_ROOT_DIR, TEST_DATA_DIR_DVC
|
||||
from image_prediction.utils import get_logger
|
||||
|
||||
logger = get_logger()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def input_batch(batch_size, input_size):
|
||||
return np.random.random_sample(size=(batch_size, *input_size))
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def dvc_test_data():
|
||||
logger.info("Pulling data with DVC...")
|
||||
# noinspection PyCallingNonCallable
|
||||
Repo(PACKAGE_ROOT_DIR).pull(targets=[str(TEST_DATA_DIR_DVC)])
|
||||
logger.info("Finished pulling data.")
|
||||
|
||||
12
test/fixtures/pdf.py
vendored
12
test/fixtures/pdf.py
vendored
@ -4,7 +4,7 @@ import fpdf
|
||||
import pytest
|
||||
|
||||
from image_prediction.locations import TEST_DATA_DIR
|
||||
from test.utils.generation.pdf import add_image, pdf_stream
|
||||
from test.utils.generation.pdf import add_image, pdf_stream, stream_pdf_bytes
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@ -18,6 +18,10 @@ def pdf(image_metadata_pairs):
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def real_pdf():
|
||||
with open(os.path.join(TEST_DATA_DIR, "f2dc689ca794fccb8cd38b95f2bf6ba9.pdf"), "rb") as f:
|
||||
yield f.read()
|
||||
def real_pdf(dvc_test_data):
|
||||
yield from stream_pdf_bytes(TEST_DATA_DIR / "f2dc689ca794fccb8cd38b95f2bf6ba9.pdf")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def bad_xref_pdf(dvc_test_data):
|
||||
yield from stream_pdf_bytes(TEST_DATA_DIR / "bad_xref.pdf")
|
||||
|
||||
2
test/fixtures/target.py
vendored
2
test/fixtures/target.py
vendored
@ -87,7 +87,7 @@ def expected_predictions_mapped_and_formatted(expected_predictions_mapped):
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def real_expected_service_response():
|
||||
def real_expected_service_response(dvc_test_data):
|
||||
with open(os.path.join(TEST_DATA_DIR, "f2dc689ca794fccb8cd38b95f2bf6ba9_predictions.json"), "r") as f:
|
||||
yield json.load(f)
|
||||
|
||||
|
||||
@ -1,3 +1,4 @@
|
||||
import json
|
||||
import random
|
||||
from operator import itemgetter
|
||||
|
||||
@ -7,12 +8,23 @@ import pytest
|
||||
from PIL import Image
|
||||
from funcy import first, rest
|
||||
|
||||
from image_prediction.exceptions import BadXref
|
||||
from image_prediction.extraction import extract_images_from_pdf
|
||||
from image_prediction.formatter.formatters.enum import EnumFormatter
|
||||
from image_prediction.image_extractor.extractor import ImageMetadataPair
|
||||
from image_prediction.image_extractor.extractors.parsable import extract_pages, has_alpha_channel, get_image_infos
|
||||
from image_prediction.image_extractor.extractors.parsable import (
|
||||
extract_pages,
|
||||
has_alpha_channel,
|
||||
get_image_infos,
|
||||
ParsablePDFImageExtractor,
|
||||
extract_valid_metadata,
|
||||
xref_to_maybe_image,
|
||||
extract_image,
|
||||
)
|
||||
from image_prediction.info import Info
|
||||
from image_prediction.locations import TEST_DATA_DIR
|
||||
from test.utils.comparison import metadata_equal, image_sets_equal
|
||||
from test.utils.generation.pdf import add_image, pdf_stream
|
||||
from test.utils.generation.pdf import add_image, pdf_stream, stream_pdf_bytes
|
||||
|
||||
|
||||
@pytest.mark.parametrize("extractor_type", ["mock"])
|
||||
@ -75,3 +87,15 @@ def test_has_alpha_channel(base_patch_metadata, suffix, mode):
|
||||
assert not list(rest(xrefs))
|
||||
|
||||
doc.close()
|
||||
|
||||
|
||||
def test_bad_xref_handling(bad_xref_pdf, dvc_test_data):
|
||||
|
||||
doc = fitz.Document(stream=bad_xref_pdf)
|
||||
metadata = extract_valid_metadata(doc, first(doc))
|
||||
xref = first(metadata)[Info.XREF]
|
||||
|
||||
with pytest.raises(BadXref):
|
||||
extract_image(doc, xref)
|
||||
|
||||
assert xref_to_maybe_image(doc, xref) is None
|
||||
|
||||
@ -60,10 +60,10 @@ def test_image_stitcher(patch_image_metadata_pairs, base_patch_metadata, base_pa
|
||||
assert images_equal(pair_stitched.image.resize((10, 10)), base_patch_image.resize((10, 10)), atol=0.4)
|
||||
|
||||
|
||||
def test_image_stitcher_with_gaps_must_succeed():
|
||||
def test_image_stitcher_with_gaps_must_succeed(dvc_test_data):
|
||||
from image_prediction.locations import TEST_DATA_DIR
|
||||
|
||||
with open(os.path.join(TEST_DATA_DIR, "stitching_with_tolerance.json")) as f:
|
||||
with open(TEST_DATA_DIR / "stitching_with_tolerance.json") as f:
|
||||
patches_metadata, base_patch_metadata = itemgetter("input", "target")(ReverseEnumFormatter(Info)(json.load(f)))
|
||||
|
||||
images = map(gray_image_from_metadata, patches_metadata)
|
||||
|
||||
@ -28,3 +28,8 @@ def add_image_to_last_page(pdf: fpdf.fpdf.FPDF, image_metadata_pair, suffix):
|
||||
with tempfile.NamedTemporaryFile(suffix=f".{suffix}") as temp_image:
|
||||
image.save(temp_image.name)
|
||||
pdf.image(temp_image.name, x=x, y=y, w=w, h=h, type=suffix)
|
||||
|
||||
|
||||
def stream_pdf_bytes(path: str):
|
||||
with open(path, "rb") as f:
|
||||
yield f.read()
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user