image-classification-service/test/unit_tests/service_estimator_test.py

16 lines
467 B
Python

import logging
import pytest
from image_prediction.utils import get_logger
# Module-level logger, obtained from the project's image_prediction.utils
# get_logger helper, with its level set to DEBUG for this test module.
# NOTE(review): presumably get_logger returns a standard logging.Logger, so
# setLevel(DEBUG) stops this logger from filtering debug-level records — confirm.
# NOTE(review): `logger` is not referenced by the visible test; it appears to be
# kept only for its side effect of raising verbosity during test runs.
logger = get_logger()
logger.setLevel(logging.DEBUG)
@pytest.mark.parametrize("estimator_type", ["mock", "keras"], scope="session")
@pytest.mark.parametrize("batch_size", [0, 1, 2, 16, 32, 64], scope="session")
def test_predict(service_estimator, input_batch, expected_predictions):
    """Check that the service estimator's predictions match the expected ones.

    The test runs once for every combination of ``estimator_type`` and
    ``batch_size``. Neither name is a parameter of this function. Pytest only
    accepts a parametrized argname that the test or one of its fixtures uses,
    so these values must be consumed by fixtures. Presumably that means
    ``service_estimator``, ``input_batch`` and ``expected_predictions``, which
    are defined outside this file (e.g. in a conftest.py) — TODO confirm.

    ``scope="session"`` groups the generated tests by parameter value. This
    lets any session-scoped fixture built from these values be reused across
    tests.

    ``batch_size=0`` presumably exercises the empty-batch edge case.
    """
    predictions = service_estimator.predict(input_batch)
    # NOTE(review): a bare ``==`` assumes list-like or scalar results. If
    # ``predict`` returns numpy arrays, ``assert`` raises
    # "truth value of an array ... is ambiguous" instead of comparing — confirm.
    assert predictions == expected_predictions