refactoring: splitting conftest logic into submodules

This commit is contained in:
Matthias Bisping 2022-04-14 18:36:13 +02:00
parent 128b325c0d
commit 3a3ab81223
5 changed files with 64 additions and 56 deletions

View File

@ -16,18 +16,20 @@ from image_prediction.exceptions import (
)
from image_prediction.image_extractor.extractor import ImageMetadataPair
from image_prediction.info import Info
from image_prediction.label_mapper.mappers.numeric import IndexMapper
from image_prediction.label_mapper.mappers.probability import ProbabilityMapper, ProbabilityMapperKeys
from image_prediction.label_mapper.mappers.probability import ProbabilityMapperKeys
from image_prediction.locations import TEST_DATA_DIR
from image_prediction.pipeline import load_pipeline
from image_prediction.utils import get_logger
from test.utils.generation.pdf import add_image, pdf_stream
from test.utils.label import map_labels
pytest_plugins = [
'test.fixtures.model',
'test.fixtures.model_store',
'test.fixtures.image',
'test.fixtures.input',
'test.fixtures.parameters',
'test.fixtures.label',
]
@ -41,21 +43,6 @@ def mute_logger():
logger.setLevel(level)
@pytest.fixture
def label_mapper(label_format, classes):
if label_format == "index":
return IndexMapper(classes)
elif label_format == "probability":
return ProbabilityMapper(classes)
else:
raise UnknownLabelFormat(f"No label mapper for label format {label_format} was specified.")
@pytest.fixture(params=["index"])
def label_format(request):
return request.param
@pytest.fixture
def expected_predictions_mapped(
label_format, batch_of_expected_string_labels, batch_of_expected_label_to_probability_mappings
@ -119,15 +106,6 @@ def output_batch_generator(expected_predictions):
return iter(expected_predictions)
@pytest.fixture
def classes():
return ["A", "B", "C"]
def map_labels(numeric_labels, classes):
return [classes[nl] for nl in numeric_labels]
@pytest.fixture
def metadata_plus_mapped_prediction(expected_predictions_mapped, metadata):
return [{"classification": epm, **mdt} for epm, mdt in zip(expected_predictions_mapped, metadata)]
@ -243,11 +221,3 @@ def base_patch_metadata(width, height, page_width, page_height):
return metadata
@pytest.fixture(params=[220, 30])
def page_height(request):
return request.param
@pytest.fixture(params=[100, 310])
def page_width(request):
return request.param

View File

@ -1,5 +1,3 @@
from operator import itemgetter
import pytest
from test.utils.generation.image import array_to_image
@ -14,23 +12,3 @@ def images(input_batch):
def input_size(alpha, __input_size):
w, h, d = __input_size
return w, h, d + alpha
@pytest.fixture(params=[False])
def alpha(request):
return request.param
@pytest.fixture(params=[{"width": 10, "height": 15, "depth": 3}, {"width": 150, "height": 100, "depth": 3}])
def __input_size(request):
return itemgetter("width", "height", "depth")(request.param)
@pytest.fixture(params=[33, 100])
def height(request):
return request.param
@pytest.fixture(params=[10, 31])
def width(request):
return request.param

25
test/fixtures/label.py vendored Normal file
View File

@ -0,0 +1,25 @@
import pytest
from image_prediction.exceptions import UnknownLabelFormat
from image_prediction.label_mapper.mappers.numeric import IndexMapper
from image_prediction.label_mapper.mappers.probability import ProbabilityMapper
@pytest.fixture
def label_mapper(label_format, classes):
if label_format == "index":
return IndexMapper(classes)
elif label_format == "probability":
return ProbabilityMapper(classes)
else:
raise UnknownLabelFormat(f"No label mapper for label format {label_format} was specified.")
@pytest.fixture(params=["index"])
def label_format(request):
return request.param
@pytest.fixture
def classes():
return ["A", "B", "C"]

View File

@ -0,0 +1,33 @@
from operator import itemgetter
import pytest
@pytest.fixture(params=[220, 30])
def page_height(request):
return request.param
@pytest.fixture(params=[100, 310])
def page_width(request):
return request.param
@pytest.fixture(params=[False])
def alpha(request):
return request.param
@pytest.fixture(params=[{"width": 10, "height": 15, "depth": 3}, {"width": 150, "height": 100, "depth": 3}])
def __input_size(request):
return itemgetter("width", "height", "depth")(request.param)
@pytest.fixture(params=[33, 100])
def height(request):
return request.param
@pytest.fixture(params=[10, 31])
def width(request):
return request.param

2
test/utils/label.py Normal file
View File

@ -0,0 +1,2 @@
def map_labels(numeric_labels, classes):
return [classes[nl] for nl in numeric_labels]