testing index and probability label format in classifier prediction test
This commit is contained in:
parent
49f9847d9a
commit
a5d3232dd0
@ -14,6 +14,10 @@ class UnknownDatabaseType(ValueError):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class UnknownLabelFormat(ValueError):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
class UnexpectedLabelFormat(ValueError):
|
class UnexpectedLabelFormat(ValueError):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|||||||
@ -9,9 +9,10 @@ logger.setLevel(logging.DEBUG)
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("estimator_type", ["mock", "keras"])
|
@pytest.mark.parametrize("estimator_type", ["mock", "keras"])
|
||||||
def test_predict(classifier, input_batch, batch_of_expected_string_labels):
|
@pytest.mark.parametrize("label_format", ["index", "probability"])
|
||||||
|
def test_predict(classifier, input_batch, expected_predictions_mapped):
|
||||||
predictions = classifier.predict(input_batch)
|
predictions = classifier.predict(input_batch)
|
||||||
assert predictions == batch_of_expected_string_labels
|
assert predictions == expected_predictions_mapped
|
||||||
|
|
||||||
|
|
||||||
def test_batch_format(input_batch):
|
def test_batch_format(input_batch):
|
||||||
|
|||||||
@ -13,12 +13,14 @@ from image_prediction.classifier.classifier import Classifier
|
|||||||
from image_prediction.classifier.image_classifier import ImageClassifier
|
from image_prediction.classifier.image_classifier import ImageClassifier
|
||||||
from image_prediction.estimator.adapter.adapter import EstimatorAdapter
|
from image_prediction.estimator.adapter.adapter import EstimatorAdapter
|
||||||
from image_prediction.estimator.preprocessor.preprocessors.basic import BasicPreprocessor
|
from image_prediction.estimator.preprocessor.preprocessors.basic import BasicPreprocessor
|
||||||
from image_prediction.exceptions import UnknownEstimatorAdapter, UnknownImageExtractor, UnknownDatabaseType
|
from image_prediction.exceptions import UnknownEstimatorAdapter, UnknownImageExtractor, UnknownDatabaseType, \
|
||||||
|
UnknownLabelFormat
|
||||||
from image_prediction.image_extractor.extractor import ImageMetadataPair
|
from image_prediction.image_extractor.extractor import ImageMetadataPair
|
||||||
from image_prediction.image_extractor.extractors.mock import ImageExtractorMock
|
from image_prediction.image_extractor.extractors.mock import ImageExtractorMock
|
||||||
from image_prediction.image_extractor.extractors.parsable import ParsablePDFImageExtractor
|
from image_prediction.image_extractor.extractors.parsable import ParsablePDFImageExtractor
|
||||||
from image_prediction.info import Info
|
from image_prediction.info import Info
|
||||||
from image_prediction.label_mapper.mappers.numeric import IndexMapper
|
from image_prediction.label_mapper.mappers.numeric import IndexMapper
|
||||||
|
from image_prediction.label_mapper.mappers.probability import ProbabilityMapper
|
||||||
from image_prediction.model_loader.database.connectors.mock import DatabaseConnectorMock
|
from image_prediction.model_loader.database.connectors.mock import DatabaseConnectorMock
|
||||||
from image_prediction.model_loader.loader import ModelLoader
|
from image_prediction.model_loader.loader import ModelLoader
|
||||||
from image_prediction.model_loader.loaders.mlflow import MlflowConnector
|
from image_prediction.model_loader.loaders.mlflow import MlflowConnector
|
||||||
@ -57,9 +59,43 @@ class EstimatorMock:
|
|||||||
return self.predict(batch)
|
return self.predict(batch)
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture()
|
@pytest.fixture
|
||||||
def label_mapper(classes):
|
def label_mapper(label_format, classes):
|
||||||
return IndexMapper(classes)
|
if label_format == "index":
|
||||||
|
return IndexMapper(classes)
|
||||||
|
elif label_format == "probability":
|
||||||
|
return ProbabilityMapper(classes)
|
||||||
|
else:
|
||||||
|
raise UnknownLabelFormat(f"No label mapper for label format {label_format} was specified.")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(params=["index"])
|
||||||
|
def label_format(request):
|
||||||
|
return request.param
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def expected_predictions_mapped(
|
||||||
|
label_format, batch_of_expected_string_labels, batch_of_expected_label_to_probability_mappings
|
||||||
|
):
|
||||||
|
if label_format == "index":
|
||||||
|
return batch_of_expected_string_labels
|
||||||
|
elif label_format == "probability":
|
||||||
|
return batch_of_expected_label_to_probability_mappings
|
||||||
|
else:
|
||||||
|
raise UnknownLabelFormat(f"No label mapper for label format {label_format} was specified.")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def expected_predictions(
|
||||||
|
label_format, batch_of_expected_numeric_labels, batch_of_expected_probability_arrays
|
||||||
|
):
|
||||||
|
if label_format == "index":
|
||||||
|
return batch_of_expected_numeric_labels
|
||||||
|
elif label_format == "probability":
|
||||||
|
return batch_of_expected_probability_arrays
|
||||||
|
else:
|
||||||
|
raise UnknownLabelFormat(f"No label mapper for label format {label_format} was specified.")
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
@ -161,8 +197,8 @@ def batch_of_expected_probability_arrays(batch_size, classes):
|
|||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def output_batch_generator(batch_of_expected_numeric_labels):
|
def output_batch_generator(expected_predictions):
|
||||||
return iter(batch_of_expected_numeric_labels)
|
return iter(expected_predictions)
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user