diff --git a/test/conftest.py b/test/conftest.py index af16f30..7461849 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -16,18 +16,20 @@ from image_prediction.exceptions import ( ) from image_prediction.image_extractor.extractor import ImageMetadataPair from image_prediction.info import Info -from image_prediction.label_mapper.mappers.numeric import IndexMapper -from image_prediction.label_mapper.mappers.probability import ProbabilityMapper, ProbabilityMapperKeys +from image_prediction.label_mapper.mappers.probability import ProbabilityMapperKeys from image_prediction.locations import TEST_DATA_DIR from image_prediction.pipeline import load_pipeline from image_prediction.utils import get_logger from test.utils.generation.pdf import add_image, pdf_stream +from test.utils.label import map_labels pytest_plugins = [ 'test.fixtures.model', 'test.fixtures.model_store', 'test.fixtures.image', 'test.fixtures.input', + 'test.fixtures.parameters', + 'test.fixtures.label', ] @@ -41,21 +43,6 @@ def mute_logger(): logger.setLevel(level) -@pytest.fixture -def label_mapper(label_format, classes): - if label_format == "index": - return IndexMapper(classes) - elif label_format == "probability": - return ProbabilityMapper(classes) - else: - raise UnknownLabelFormat(f"No label mapper for label format {label_format} was specified.") - - -@pytest.fixture(params=["index"]) -def label_format(request): - return request.param - - @pytest.fixture def expected_predictions_mapped( label_format, batch_of_expected_string_labels, batch_of_expected_label_to_probability_mappings @@ -119,15 +106,6 @@ def output_batch_generator(expected_predictions): return iter(expected_predictions) -@pytest.fixture -def classes(): - return ["A", "B", "C"] - - -def map_labels(numeric_labels, classes): - return [classes[nl] for nl in numeric_labels] - - @pytest.fixture def metadata_plus_mapped_prediction(expected_predictions_mapped, metadata): return [{"classification": epm, **mdt} for epm, mdt in zip(expected_predictions_mapped, metadata)] @@ -243,11 +221,3 @@ def base_patch_metadata(width, height, page_width, page_height): return metadata -@pytest.fixture(params=[220, 30]) -def page_height(request): - return request.param - - -@pytest.fixture(params=[100, 310]) -def page_width(request): - return request.param diff --git a/test/fixtures/image.py b/test/fixtures/image.py index e066657..fb1b7f3 100644 --- a/test/fixtures/image.py +++ b/test/fixtures/image.py @@ -1,5 +1,3 @@ -from operator import itemgetter - import pytest from test.utils.generation.image import array_to_image @@ -14,23 +12,3 @@ def images(input_batch): def input_size(alpha, __input_size): w, h, d = __input_size return w, h, d + alpha - - -@pytest.fixture(params=[False]) -def alpha(request): - return request.param - - -@pytest.fixture(params=[{"width": 10, "height": 15, "depth": 3}, {"width": 150, "height": 100, "depth": 3}]) -def __input_size(request): - return itemgetter("width", "height", "depth")(request.param) - - -@pytest.fixture(params=[33, 100]) -def height(request): - return request.param - - -@pytest.fixture(params=[10, 31]) -def width(request): - return request.param \ No newline at end of file diff --git a/test/fixtures/label.py b/test/fixtures/label.py new file mode 100644 index 0000000..2b61e47 --- /dev/null +++ b/test/fixtures/label.py @@ -0,0 +1,25 @@ +import pytest + +from image_prediction.exceptions import UnknownLabelFormat +from image_prediction.label_mapper.mappers.numeric import IndexMapper +from image_prediction.label_mapper.mappers.probability import ProbabilityMapper + + +@pytest.fixture +def label_mapper(label_format, classes): + if label_format == "index": + return IndexMapper(classes) + elif label_format == "probability": + return ProbabilityMapper(classes) + else: + raise UnknownLabelFormat(f"No label mapper for label format {label_format} was specified.") + + +@pytest.fixture(params=["index"]) +def label_format(request): + return request.param + + +@pytest.fixture +def classes(): + return ["A", "B", "C"] \ No newline at end of file diff --git a/test/fixtures/parameters.py b/test/fixtures/parameters.py index e69de29..1ce4a06 100644 --- a/test/fixtures/parameters.py +++ b/test/fixtures/parameters.py @@ -0,0 +1,33 @@ +from operator import itemgetter + +import pytest + + +@pytest.fixture(params=[220, 30]) +def page_height(request): + return request.param + + +@pytest.fixture(params=[100, 310]) +def page_width(request): + return request.param + + +@pytest.fixture(params=[False]) +def alpha(request): + return request.param + + +@pytest.fixture(params=[{"width": 10, "height": 15, "depth": 3}, {"width": 150, "height": 100, "depth": 3}]) +def __input_size(request): + return itemgetter("width", "height", "depth")(request.param) + + +@pytest.fixture(params=[33, 100]) +def height(request): + return request.param + + +@pytest.fixture(params=[10, 31]) +def width(request): + return request.param \ No newline at end of file diff --git a/test/utils/label.py b/test/utils/label.py new file mode 100644 index 0000000..0e5babc --- /dev/null +++ b/test/utils/label.py @@ -0,0 +1,2 @@ +def map_labels(numeric_labels, classes): + return [classes[nl] for nl in numeric_labels] \ No newline at end of file