"""pytest fixtures for building model databases, database connectors and model loaders.

Fixtures referenced here but not defined in this module (``model``, ``classes``,
``database_type``) are expected to be provided elsewhere, e.g. via another
conftest or test parametrization — NOTE(review): confirm their source.
"""

import random
import string

import pytest

from image_prediction.exceptions import UnknownDatabaseType
from image_prediction.model_loader.database.connectors.mock import DatabaseConnectorMock
from image_prediction.model_loader.loader import ModelLoader
from image_prediction.model_loader.loaders.mlflow import MlflowConnector
from image_prediction.redai_adapter.mlflow import MlflowModelReader


@pytest.fixture
def model_database_record_identifier():
    """Return a random 10-character identifier made of ASCII letters.

    ``random.sample`` draws without replacement, so no letter repeats within
    the identifier.
    """
    return "".join(random.sample(string.ascii_letters, k=10))


@pytest.fixture
def model_database_record(model, classes):
    """Bundle the ``model`` and ``classes`` fixtures into a single record dict.

    Returns:
        dict with keys ``"model"`` and ``"classes"``.
    """
    return {"model": model, "classes": classes}


@pytest.fixture
def model_database(model_database_record, model_database_record_identifier):
    """Return an in-memory database mapping the random identifier to the record.

    This dict is the backing store handed to ``DatabaseConnectorMock`` by the
    ``database_connector`` fixture.
    """
    return {model_database_record_identifier: model_database_record}


@pytest.fixture
def database_connector(database_type, model_database, mlflow_reader):
    """Return a database connector selected by ``database_type``.

    - ``"mock"``: a ``DatabaseConnectorMock`` over ``model_database``.
    - ``"mlflow"``: an ``MlflowConnector`` wrapping ``mlflow_reader``.

    Raises:
        UnknownDatabaseType: if ``database_type`` is neither ``"mock"`` nor ``"mlflow"``.
    """
    # NOTE(review): pytest resolves both ``model_database`` and ``mlflow_reader``
    # for every test that uses this fixture, even though only one of them is
    # used for a given ``database_type``.
    if database_type == "mock":
        return DatabaseConnectorMock(model_database)
    elif database_type == "mlflow":
        return MlflowConnector(mlflow_reader)
    else:
        raise UnknownDatabaseType(f"No connector for database type {database_type} was specified.")


@pytest.fixture
def model_loader(database_connector):
    """Return a ``ModelLoader`` backed by the selected ``database_connector``."""
    return ModelLoader(database_connector)


@pytest.fixture
def mlflow_run_id():
    """Return the MLflow run id configured at ``CONFIG.service.mlflow_run_id``."""
    # Imported inside the fixture rather than at module level; presumably so the
    # config is only loaded when this fixture is actually used — confirm.
    from image_prediction.config import CONFIG

    return CONFIG.service.mlflow_run_id


@pytest.fixture
def mlruns_dir():
    """Return the ``MLRUNS_DIR`` location defined in ``image_prediction.locations``."""
    # Local import, mirroring ``mlflow_run_id``.
    from image_prediction.locations import MLRUNS_DIR

    return MLRUNS_DIR


@pytest.fixture
def mlflow_reader(mlruns_dir):
    """Return an ``MlflowModelReader`` reading from ``mlruns_dir``."""
    return MlflowModelReader(mlruns_dir)