61 lines
1.5 KiB
Python
61 lines
1.5 KiB
Python
import random
|
|
import string
|
|
|
|
import pytest
|
|
|
|
from image_prediction.exceptions import UnknownDatabaseType
|
|
from image_prediction.model_loader.database.connectors.mock import DatabaseConnectorMock
|
|
from image_prediction.model_loader.loader import ModelLoader
|
|
from image_prediction.model_loader.loaders.mlflow import MlflowConnector
|
|
from image_prediction.redai_adapter.mlflow import MlflowModelReader
|
|
|
|
|
|
@pytest.fixture
def model_database_record_identifier():
    """Return a random 10-character identifier for the fake model record.

    Characters are drawn without replacement from ``string.ascii_letters``,
    so no letter repeats within a single identifier.
    """
    distinct_letters = random.sample(string.ascii_letters, k=10)
    return "".join(distinct_letters)
|
|
|
|
|
|
@pytest.fixture
def model_database_record(model, classes):
    """Bundle the ``model`` and ``classes`` fixtures into one database record.

    Returns:
        A dict with the keys ``"model"`` and ``"classes"``.
    """
    return dict(model=model, classes=classes)
|
|
|
|
|
|
@pytest.fixture
def model_database(model_database_record, model_database_record_identifier):
    """Build an in-memory model database holding exactly one record.

    Returns:
        A dict mapping the generated record identifier to the record.
    """
    database = {}
    database[model_database_record_identifier] = model_database_record
    return database
|
|
|
|
|
|
@pytest.fixture
def database_connector(database_type, model_database, mlflow_reader):
    """Return a database connector matching ``database_type``.

    ``database_type`` is not defined in this module; presumably it is supplied
    by another fixture or by test parametrization (verify against callers).

    Returns:
        ``DatabaseConnectorMock`` over ``model_database`` for ``"mock"``, or
        ``MlflowConnector`` over ``mlflow_reader`` for ``"mlflow"``.

    Raises:
        UnknownDatabaseType: if ``database_type`` is neither value.
    """
    if database_type == "mock":
        return DatabaseConnectorMock(model_database)
    if database_type == "mlflow":
        return MlflowConnector(mlflow_reader)
    raise UnknownDatabaseType(f"No connector for database type {database_type} was specified.")
|
|
|
|
|
|
@pytest.fixture
def model_loader(database_connector):
    """Return a ``ModelLoader`` backed by the selected database connector."""
    loader = ModelLoader(database_connector)
    return loader
|
|
|
|
|
|
@pytest.fixture
def mlflow_run_id():
    """Return the MLflow run id stored at ``CONFIG.service.run_id``."""
    # Imported inside the fixture rather than at module level; presumably so
    # the config is only loaded when this fixture is actually used (confirm).
    from image_prediction.config import CONFIG as config

    service_settings = config.service
    return service_settings.run_id
|
|
|
|
|
|
@pytest.fixture
def mlruns_dir():
    """Return the ``MLRUNS_DIR`` location from ``image_prediction.locations``."""
    # Local import, mirroring the original fixture.
    from image_prediction.locations import MLRUNS_DIR as mlruns_location

    return mlruns_location
|
|
|
|
|
|
@pytest.fixture
def mlflow_reader(mlruns_dir):
    """Return an ``MlflowModelReader`` constructed from the ``mlruns_dir`` fixture."""
    reader = MlflowModelReader(mlruns_dir)
    return reader