diff --git a/test/conftest.py b/test/conftest.py index 95cd2b9..582a3db 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -2,7 +2,6 @@ import json import logging import os import random -import string from functools import partial from itertools import starmap from operator import itemgetter @@ -13,7 +12,6 @@ import pytest from funcy import rcompose, merge from image_prediction.exceptions import ( - UnknownDatabaseType, UnknownLabelFormat, ) from image_prediction.image_extractor.extractor import ImageMetadataPair @@ -21,16 +19,15 @@ from image_prediction.info import Info from image_prediction.label_mapper.mappers.numeric import IndexMapper from image_prediction.label_mapper.mappers.probability import ProbabilityMapper, ProbabilityMapperKeys from image_prediction.locations import TEST_DATA_DIR -from image_prediction.model_loader.database.connectors.mock import DatabaseConnectorMock -from image_prediction.model_loader.loader import ModelLoader -from image_prediction.model_loader.loaders.mlflow import MlflowConnector from image_prediction.pipeline import load_pipeline -from image_prediction.redai_adapter.mlflow import MlflowModelReader from image_prediction.utils import get_logger from test.utils.generation.image import array_to_image from test.utils.generation.pdf import add_image, pdf_stream -pytest_plugins = ['test.utils.model'] +pytest_plugins = [ + 'test.fixtures.model', + 'test.fixtures.model_store', +] @@ -231,57 +228,6 @@ def pdf(image_metadata_pairs): return pdf_stream(pdf) -@pytest.fixture -def model_database_record_identifier(): - return "".join(random.sample(string.ascii_letters, k=10)) - - -@pytest.fixture -def model_database_record(model, classes): - return {"model": model, "classes": classes} - - -@pytest.fixture -def model_database(model_database_record, model_database_record_identifier): - return {model_database_record_identifier: model_database_record} - - -@pytest.fixture -def database_connector(database_type, model_database, mlflow_reader): - if database_type == "mock": - return DatabaseConnectorMock(model_database) - - elif database_type == "mlflow": - return MlflowConnector(mlflow_reader) - - else: - raise UnknownDatabaseType(f"No connector for database type {database_type} was specified.") - - -@pytest.fixture -def model_loader(database_connector): - return ModelLoader(database_connector) - - -@pytest.fixture -def mlflow_run_id(): - from image_prediction.config import CONFIG - - return CONFIG.service.run_id - - -@pytest.fixture -def mlruns_dir(): - from image_prediction.locations import MLRUNS_DIR - - return MLRUNS_DIR - - -@pytest.fixture -def mlflow_reader(mlruns_dir): - return MlflowModelReader(mlruns_dir) - - @pytest.fixture def real_pdf(): with open(os.path.join(TEST_DATA_DIR, "f2dc689ca794fccb8cd38b95f2bf6ba9.pdf"), "rb") as f: diff --git a/test/utils/model_store.py b/test/fixtures/image.py similarity index 100% rename from test/utils/model_store.py rename to test/fixtures/image.py diff --git a/test/fixtures/input.py b/test/fixtures/input.py new file mode 100644 index 0000000..e69de29 diff --git a/test/utils/model.py b/test/fixtures/model.py similarity index 100% rename from test/utils/model.py rename to test/fixtures/model.py diff --git a/test/fixtures/model_store.py b/test/fixtures/model_store.py new file mode 100644 index 0000000..bc71634 --- /dev/null +++ b/test/fixtures/model_store.py @@ -0,0 +1,61 @@ +import random +import string + +import pytest + +from image_prediction.exceptions import UnknownDatabaseType +from image_prediction.model_loader.database.connectors.mock import DatabaseConnectorMock +from image_prediction.model_loader.loader import ModelLoader +from image_prediction.model_loader.loaders.mlflow import MlflowConnector +from image_prediction.redai_adapter.mlflow import MlflowModelReader + + +@pytest.fixture +def model_database_record_identifier(): + return "".join(random.sample(string.ascii_letters, k=10)) + + +@pytest.fixture +def model_database_record(model, classes): + return {"model": model, "classes": classes} + + +@pytest.fixture +def model_database(model_database_record, model_database_record_identifier): + return {model_database_record_identifier: model_database_record} + + +@pytest.fixture +def database_connector(database_type, model_database, mlflow_reader): + if database_type == "mock": + return DatabaseConnectorMock(model_database) + + elif database_type == "mlflow": + return MlflowConnector(mlflow_reader) + + else: + raise UnknownDatabaseType(f"No connector for database type {database_type} was specified.") + + +@pytest.fixture +def model_loader(database_connector): + return ModelLoader(database_connector) + + +@pytest.fixture +def mlflow_run_id(): + from image_prediction.config import CONFIG + + return CONFIG.service.run_id + + +@pytest.fixture +def mlruns_dir(): + from image_prediction.locations import MLRUNS_DIR + + return MLRUNS_DIR + + +@pytest.fixture +def mlflow_reader(mlruns_dir): + return MlflowModelReader(mlruns_dir) \ No newline at end of file