2022-03-25 11:23:07 +01:00

37 lines
1011 B
Python

import numpy as np
import pytest
from image_prediction.estimator.mock import EstimatorMock
from image_prediction.model.mock import ModelMock
@pytest.fixture(scope="session")
def estimator():
    """Provide one mock estimator shared by every test in the session.

    Because the fixture is session-scoped, the same instance is handed to all
    requesting tests; tests may mutate it (e.g. assign ``output_batch``).
    """
    mock_estimator = EstimatorMock()
    return mock_estimator
@pytest.fixture(scope="session")
def batches(batch_size):
    """Build an ``(input_batch, output_batch)`` pair holding ``batch_size`` samples.

    Args:
        batch_size: Number of samples along the leading axis; supplied via
            ``pytest.mark.parametrize``.

    Returns:
        A tuple ``(input_batch, output_batch)``. ``input_batch`` is
        standard-normal noise of shape ``(batch_size, 10, 15)``;
        ``output_batch`` has the same shape with every element equal to 42.
    """
    shape = (batch_size, 10, 15)
    # A dedicated, seeded generator keeps the input data reproducible between
    # runs and avoids touching NumPy's global random state.
    rng = np.random.default_rng(seed=0)
    input_batch = rng.normal(size=shape)
    # Constant target: every element is 42 (what randint(42, 43) produced).
    output_batch = np.full(shape, 42)
    return input_batch, output_batch
@pytest.fixture(scope="session")
def classes():
    """Return the class labels used by the tests: ``["A", "B", "C"]``."""
    labels = list("ABC")
    return labels
@pytest.fixture(scope="session")
def model(model_type, estimator):
    """Construct the model under test for the parametrized ``model_type``.

    Args:
        model_type: Which model implementation to build; supplied via
            ``pytest.mark.parametrize``. Only ``"mock"`` is handled here.
        estimator: The session-wide mock estimator the model wraps.

    Returns:
        A ``ModelMock`` wrapping ``estimator`` when ``model_type == "mock"``.

    Raises:
        ValueError: If ``model_type`` is not a supported model type.
    """
    if model_type == "mock":
        return ModelMock(estimator)
    # Fail loudly here rather than returning None, which would only surface
    # later as an unrelated AttributeError inside the test body.
    raise ValueError(f"Unsupported model_type: {model_type!r}")
@pytest.mark.parametrize("model_type", ["mock"], scope="session")
@pytest.mark.parametrize("batch_size", [0, 1, 2, 16, 32, 64], scope="session")
def test_predict(model, batches):
    """Predictions must equal the output batch the estimator was primed with.

    The test assigns ``output_batch`` to ``model.estimator.output_batch`` and
    then requires ``model.predict(input_batch)`` to match it exactly, in both
    shape and values.
    """
    input_batch, output_batch = batches
    # Prime the shared (session-scoped) mock estimator with the expected output.
    model.estimator.output_batch = output_batch
    prediction = model.predict(input_batch)
    # np.array_equal also requires identical shapes. The former
    # np.all(np.equal(...)) broadcast its operands, and for batch_size=0 it was
    # vacuously True for any empty-broadcastable prediction.
    assert np.array_equal(prediction, output_batch)