Matthias Bisping 99d8e921db renaming
2022-03-30 13:57:29 +02:00

26 lines
666 B
Python

import logging
import pytest
from image_prediction.utils import get_logger
# Module-level logger for this test module: obtained from the project's
# get_logger() helper, with its level set to DEBUG.
logger = get_logger()
logger.setLevel(logging.DEBUG)
@pytest.mark.parametrize("estimator_type", ["mock", "keras"])
def test_predict(classifier, input_batch, expected_batch_string_labels):
    """Check that the classifier's predictions match the expected labels.

    The test is parametrized over ``estimator_type`` with the values
    ``"mock"`` and ``"keras"``.
    """
    # NOTE(review): ``estimator_type`` is not a direct argument of this test;
    # presumably a fixture behind ``classifier`` consumes it -- confirm in
    # the project's conftest.
    assert classifier.predict(input_batch) == expected_batch_string_labels
def test_batch_format(input_batch):
    """Check the layout of the input batch.

    Two properties are asserted, in this order:
    the last axis has size 3 (channels-last), and the batch has exactly
    four axes (a fourth-order tensor).
    """
    # Channels are last: the trailing axis must hold exactly 3 entries.
    trailing_axis_size = input_batch.shape[-1]
    assert trailing_axis_size == 3

    # Fourth-order tensor: the batch must have exactly four dimensions.
    assert input_batch.ndim == 4