diff --git a/image_prediction/exceptions.py b/image_prediction/exceptions.py index 3b82070..476f1e8 100644 --- a/image_prediction/exceptions.py +++ b/image_prediction/exceptions.py @@ -14,6 +14,10 @@ class UnknownDatabaseType(ValueError): pass +class UnknownLabelFormat(ValueError): + pass + + class UnexpectedLabelFormat(ValueError): pass diff --git a/test/unit_tests/classifier_test.py b/test/unit_tests/classifier_test.py index 699154e..a6a695c 100644 --- a/test/unit_tests/classifier_test.py +++ b/test/unit_tests/classifier_test.py @@ -9,9 +9,10 @@ logger.setLevel(logging.DEBUG) @pytest.mark.parametrize("estimator_type", ["mock", "keras"]) -def test_predict(classifier, input_batch, batch_of_expected_string_labels): +@pytest.mark.parametrize("label_format", ["index", "probability"]) +def test_predict(classifier, input_batch, expected_predictions_mapped): predictions = classifier.predict(input_batch) - assert predictions == batch_of_expected_string_labels + assert predictions == expected_predictions_mapped def test_batch_format(input_batch): diff --git a/test/unit_tests/conftest.py b/test/unit_tests/conftest.py index 5fc55d6..cc04501 100644 --- a/test/unit_tests/conftest.py +++ b/test/unit_tests/conftest.py @@ -13,12 +13,14 @@ from image_prediction.classifier.classifier import Classifier from image_prediction.classifier.image_classifier import ImageClassifier from image_prediction.estimator.adapter.adapter import EstimatorAdapter from image_prediction.estimator.preprocessor.preprocessors.basic import BasicPreprocessor -from image_prediction.exceptions import UnknownEstimatorAdapter, UnknownImageExtractor, UnknownDatabaseType +from image_prediction.exceptions import UnknownEstimatorAdapter, UnknownImageExtractor, UnknownDatabaseType, \ + UnknownLabelFormat from image_prediction.image_extractor.extractor import ImageMetadataPair from image_prediction.image_extractor.extractors.mock import ImageExtractorMock from image_prediction.image_extractor.extractors.parsable import ParsablePDFImageExtractor from image_prediction.info import Info from image_prediction.label_mapper.mappers.numeric import IndexMapper +from image_prediction.label_mapper.mappers.probability import ProbabilityMapper from image_prediction.model_loader.database.connectors.mock import DatabaseConnectorMock from image_prediction.model_loader.loader import ModelLoader from image_prediction.model_loader.loaders.mlflow import MlflowConnector @@ -57,9 +59,43 @@ class EstimatorMock: return self.predict(batch) -@pytest.fixture() -def label_mapper(classes): - return IndexMapper(classes) +@pytest.fixture +def label_mapper(label_format, classes): + if label_format == "index": + return IndexMapper(classes) + elif label_format == "probability": + return ProbabilityMapper(classes) + else: + raise UnknownLabelFormat(f"No label mapper for label format {label_format} was specified.") + + +@pytest.fixture(params=["index"]) +def label_format(request): + return request.param + + +@pytest.fixture +def expected_predictions_mapped( + label_format, batch_of_expected_string_labels, batch_of_expected_label_to_probability_mappings +): + if label_format == "index": + return batch_of_expected_string_labels + elif label_format == "probability": + return batch_of_expected_label_to_probability_mappings + else: + raise UnknownLabelFormat(f"No label mapper for label format {label_format} was specified.") + + +@pytest.fixture +def expected_predictions( + label_format, batch_of_expected_numeric_labels, batch_of_expected_probability_arrays +): + if label_format == "index": + return batch_of_expected_numeric_labels + elif label_format == "probability": + return batch_of_expected_probability_arrays + else: + raise UnknownLabelFormat(f"No label mapper for label format {label_format} was specified.") @pytest.fixture @@ -161,8 +197,8 @@ def batch_of_expected_probability_arrays(batch_size, classes): @pytest.fixture -def output_batch_generator(batch_of_expected_numeric_labels): - return iter(batch_of_expected_numeric_labels) +def output_batch_generator(expected_predictions): + return iter(expected_predictions) @pytest.fixture