Matthias Bisping 0cf8e047c5 Refactoring
2023-02-06 13:22:33 +01:00

22 lines
578 B
Python

import numpy as np
import pytest
from dvc.repo import Repo
from image_prediction.locations import PACKAGE_ROOT_DIR, TEST_DATA_DIR_DVC
from image_prediction.utils import get_logger
# Logger shared by the fixtures in this module, obtained from the project's get_logger helper.
logger = get_logger()
@pytest.fixture
def input_batch(batch_size, input_size):
    """Provide a batch of uniformly random samples for model-input tests.

    ``batch_size`` and ``input_size`` are themselves fixtures, presumably
    defined or parametrized elsewhere (TODO confirm). ``input_size`` must be
    an iterable of dimension lengths.

    Returns:
        A float ndarray of shape ``(batch_size, *input_size)`` whose values
        come from ``np.random.random_sample``, i.e. the half-open interval
        [0.0, 1.0). The global NumPy RNG is used, so the values are not seeded.
    """
    # Prepend the batch axis to the per-sample dimensions.
    shape = (batch_size,) + tuple(input_size)
    return np.random.random_sample(size=shape)
@pytest.fixture(scope="session")
def dvc_test_data():
    """Pull the DVC-tracked test data once per test session.

    Opens the DVC repository rooted at ``PACKAGE_ROOT_DIR`` and pulls only the
    ``TEST_DATA_DIR_DVC`` target. The fixture returns ``None``; tests request
    it purely for the side effect of having the data available locally.
    Any error raised by the pull propagates, so dependent tests error out.
    """
    logger.info("Pulling data with DVC...")
    # Use the Repo as a context manager so Repo.close() runs once the pull
    # finishes. The original never closed the repository object.
    with Repo(PACKAGE_ROOT_DIR) as repo:
        # noinspection PyCallingNonCallable
        repo.pull(targets=[str(TEST_DATA_DIR_DVC)])
    logger.info("Finished pulling data.")