import numpy as np
import pytest


@pytest.fixture
def input_batch(batch_size, input_size) -> np.ndarray:
    """Return a random batch of inputs.

    Both arguments are resolved by pytest as fixtures of the same name;
    they are not defined in this part of the file and are presumably
    supplied by a conftest or by parametrization (confirm against the
    test suite).

    Args:
        batch_size: Length of the leading (batch) axis.
        input_size: Iterable of per-sample dimensions; it is unpacked
            after ``batch_size`` to form the full array shape.

    Returns:
        A float array of shape ``(batch_size, *input_size)`` whose values
        are drawn uniformly from the half-open interval ``[0.0, 1.0)``.
    """
    # Uses NumPy's global RNG on purpose, so that any ``np.random.seed``
    # call made elsewhere in the test suite keeps controlling this data.
    shape = (batch_size, *input_size)
    return np.random.random_sample(size=shape)