22 lines
578 B
Python
22 lines
578 B
Python
import numpy as np
|
|
import pytest
|
|
from dvc.repo import Repo
|
|
|
|
from image_prediction.locations import PACKAGE_ROOT_DIR, TEST_DATA_DIR_DVC
|
|
from image_prediction.utils import get_logger
|
|
|
|
# Module-level logger from the project's get_logger() helper; the
# dvc_test_data fixture uses it to report DVC pull progress.
logger = get_logger()
|
|
|
|
|
|
@pytest.fixture
def input_batch(batch_size, input_size):
    """Return a random input batch with values drawn uniformly from [0.0, 1.0).

    ``batch_size`` and ``input_size`` are resolved by pytest (other fixtures or
    parametrization, not defined in this chunk). ``input_size`` is unpacked into
    the trailing dimensions, so it must be iterable.

    Returns:
        A NumPy float array of shape ``(batch_size, *input_size)`` drawn from
        NumPy's legacy global random state (not seeded here).
    """
    batch_shape = (batch_size, *input_size)
    return np.random.random_sample(size=batch_shape)
|
|
|
|
|
|
@pytest.fixture(scope="session")
def dvc_test_data():
    """Pull the DVC-tracked test data once per test session.

    Runs a DVC pull of ``TEST_DATA_DIR_DVC`` against the DVC project rooted at
    ``PACKAGE_ROOT_DIR``. Tests depend on this fixture only for its side effect;
    it returns ``None``. Exceptions raised by the pull are not caught, so they
    propagate to every test that requests the fixture.

    The ``Repo`` object is used as a context manager so that ``Repo.close()``
    runs as soon as the pull finishes. Previously the instance was created
    inline and never closed, so its handles were released only when it was
    garbage-collected.

    NOTE(review): the ``with`` form requires a DVC version whose ``Repo``
    implements ``__enter__``/``__exit__`` (DVC >= 2). Confirm against the
    pinned dvc version.
    """
    logger.info("Pulling data with DVC...")
    with Repo(PACKAGE_ROOT_DIR) as repo:
        # noinspection PyCallingNonCallable
        repo.pull(targets=[str(TEST_DATA_DIR_DVC)])
    logger.info("Finished pulling data.")
|