testing index and probability label format in classifier prediction test

This commit is contained in:
Matthias Bisping 2022-03-30 16:34:17 +02:00
parent 49f9847d9a
commit a5d3232dd0
3 changed files with 49 additions and 8 deletions

View File

@ -14,6 +14,10 @@ class UnknownDatabaseType(ValueError):
pass
class UnknownLabelFormat(ValueError):
pass
class UnexpectedLabelFormat(ValueError):
pass

View File

@ -9,9 +9,10 @@ logger.setLevel(logging.DEBUG)
@pytest.mark.parametrize("estimator_type", ["mock", "keras"])
def test_predict(classifier, input_batch, batch_of_expected_string_labels):
@pytest.mark.parametrize("label_format", ["index", "probability"])
def test_predict(classifier, input_batch, expected_predictions_mapped):
predictions = classifier.predict(input_batch)
assert predictions == batch_of_expected_string_labels
assert predictions == expected_predictions_mapped
def test_batch_format(input_batch):

View File

@ -13,12 +13,14 @@ from image_prediction.classifier.classifier import Classifier
from image_prediction.classifier.image_classifier import ImageClassifier
from image_prediction.estimator.adapter.adapter import EstimatorAdapter
from image_prediction.estimator.preprocessor.preprocessors.basic import BasicPreprocessor
from image_prediction.exceptions import UnknownEstimatorAdapter, UnknownImageExtractor, UnknownDatabaseType
from image_prediction.exceptions import UnknownEstimatorAdapter, UnknownImageExtractor, UnknownDatabaseType, \
UnknownLabelFormat
from image_prediction.image_extractor.extractor import ImageMetadataPair
from image_prediction.image_extractor.extractors.mock import ImageExtractorMock
from image_prediction.image_extractor.extractors.parsable import ParsablePDFImageExtractor
from image_prediction.info import Info
from image_prediction.label_mapper.mappers.numeric import IndexMapper
from image_prediction.label_mapper.mappers.probability import ProbabilityMapper
from image_prediction.model_loader.database.connectors.mock import DatabaseConnectorMock
from image_prediction.model_loader.loader import ModelLoader
from image_prediction.model_loader.loaders.mlflow import MlflowConnector
@ -57,9 +59,43 @@ class EstimatorMock:
return self.predict(batch)
@pytest.fixture()
def label_mapper(classes):
return IndexMapper(classes)
@pytest.fixture
def label_mapper(label_format, classes):
if label_format == "index":
return IndexMapper(classes)
elif label_format == "probability":
return ProbabilityMapper(classes)
else:
raise UnknownLabelFormat(f"No label mapper for label format {label_format} was specified.")
@pytest.fixture(params=["index"])
def label_format(request):
return request.param
@pytest.fixture
def expected_predictions_mapped(
label_format, batch_of_expected_string_labels, batch_of_expected_label_to_probability_mappings
):
if label_format == "index":
return batch_of_expected_string_labels
elif label_format == "probability":
return batch_of_expected_label_to_probability_mappings
else:
raise UnknownLabelFormat(f"No label mapper for label format {label_format} was specified.")
@pytest.fixture
def expected_predictions(
label_format, batch_of_expected_numeric_labels, batch_of_expected_probability_arrays
):
if label_format == "index":
return batch_of_expected_numeric_labels
elif label_format == "probability":
return batch_of_expected_probability_arrays
else:
raise UnknownLabelFormat(f"No label mapper for label format {label_format} was specified.")
@pytest.fixture
@ -161,8 +197,8 @@ def batch_of_expected_probability_arrays(batch_size, classes):
@pytest.fixture
def output_batch_generator(batch_of_expected_numeric_labels):
return iter(batch_of_expected_numeric_labels)
def output_batch_generator(expected_predictions):
return iter(expected_predictions)
@pytest.fixture