diff --git a/test/unit_tests/classifier_test.py b/test/unit_tests/classifier_test.py index a6a695c..9a34489 100644 --- a/test/unit_tests/classifier_test.py +++ b/test/unit_tests/classifier_test.py @@ -10,8 +10,8 @@ logger.setLevel(logging.DEBUG) @pytest.mark.parametrize("estimator_type", ["mock", "keras"]) @pytest.mark.parametrize("label_format", ["index", "probability"]) -def test_predict(classifier, input_batch, expected_predictions_mapped): - predictions = classifier.predict(input_batch) +def test_classifier(classifier, input_batch, expected_predictions_mapped): + predictions = classifier(input_batch) assert predictions == expected_predictions_mapped