testing index and probability label format in classifier prediction test
This commit is contained in:
parent
49f9847d9a
commit
a5d3232dd0
@ -14,6 +14,10 @@ class UnknownDatabaseType(ValueError):
|
||||
pass
|
||||
|
||||
|
||||
class UnknownLabelFormat(ValueError):
|
||||
pass
|
||||
|
||||
|
||||
class UnexpectedLabelFormat(ValueError):
|
||||
pass
|
||||
|
||||
|
||||
@ -9,9 +9,10 @@ logger.setLevel(logging.DEBUG)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("estimator_type", ["mock", "keras"])
|
||||
def test_predict(classifier, input_batch, batch_of_expected_string_labels):
|
||||
@pytest.mark.parametrize("label_format", ["index", "probability"])
|
||||
def test_predict(classifier, input_batch, expected_predictions_mapped):
|
||||
predictions = classifier.predict(input_batch)
|
||||
assert predictions == batch_of_expected_string_labels
|
||||
assert predictions == expected_predictions_mapped
|
||||
|
||||
|
||||
def test_batch_format(input_batch):
|
||||
|
||||
@ -13,12 +13,14 @@ from image_prediction.classifier.classifier import Classifier
|
||||
from image_prediction.classifier.image_classifier import ImageClassifier
|
||||
from image_prediction.estimator.adapter.adapter import EstimatorAdapter
|
||||
from image_prediction.estimator.preprocessor.preprocessors.basic import BasicPreprocessor
|
||||
from image_prediction.exceptions import UnknownEstimatorAdapter, UnknownImageExtractor, UnknownDatabaseType
|
||||
from image_prediction.exceptions import UnknownEstimatorAdapter, UnknownImageExtractor, UnknownDatabaseType, \
|
||||
UnknownLabelFormat
|
||||
from image_prediction.image_extractor.extractor import ImageMetadataPair
|
||||
from image_prediction.image_extractor.extractors.mock import ImageExtractorMock
|
||||
from image_prediction.image_extractor.extractors.parsable import ParsablePDFImageExtractor
|
||||
from image_prediction.info import Info
|
||||
from image_prediction.label_mapper.mappers.numeric import IndexMapper
|
||||
from image_prediction.label_mapper.mappers.probability import ProbabilityMapper
|
||||
from image_prediction.model_loader.database.connectors.mock import DatabaseConnectorMock
|
||||
from image_prediction.model_loader.loader import ModelLoader
|
||||
from image_prediction.model_loader.loaders.mlflow import MlflowConnector
|
||||
@ -57,9 +59,43 @@ class EstimatorMock:
|
||||
return self.predict(batch)
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def label_mapper(classes):
|
||||
return IndexMapper(classes)
|
||||
@pytest.fixture
|
||||
def label_mapper(label_format, classes):
|
||||
if label_format == "index":
|
||||
return IndexMapper(classes)
|
||||
elif label_format == "probability":
|
||||
return ProbabilityMapper(classes)
|
||||
else:
|
||||
raise UnknownLabelFormat(f"No label mapper for label format {label_format} was specified.")
|
||||
|
||||
|
||||
@pytest.fixture(params=["index"])
|
||||
def label_format(request):
|
||||
return request.param
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def expected_predictions_mapped(
|
||||
label_format, batch_of_expected_string_labels, batch_of_expected_label_to_probability_mappings
|
||||
):
|
||||
if label_format == "index":
|
||||
return batch_of_expected_string_labels
|
||||
elif label_format == "probability":
|
||||
return batch_of_expected_label_to_probability_mappings
|
||||
else:
|
||||
raise UnknownLabelFormat(f"No label mapper for label format {label_format} was specified.")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def expected_predictions(
|
||||
label_format, batch_of_expected_numeric_labels, batch_of_expected_probability_arrays
|
||||
):
|
||||
if label_format == "index":
|
||||
return batch_of_expected_numeric_labels
|
||||
elif label_format == "probability":
|
||||
return batch_of_expected_probability_arrays
|
||||
else:
|
||||
raise UnknownLabelFormat(f"No label mapper for label format {label_format} was specified.")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@ -161,8 +197,8 @@ def batch_of_expected_probability_arrays(batch_size, classes):
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def output_batch_generator(batch_of_expected_numeric_labels):
|
||||
return iter(batch_of_expected_numeric_labels)
|
||||
def output_batch_generator(expected_predictions):
|
||||
return iter(expected_predictions)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user