refactoring: splitting conftest logic into submodules

This commit is contained in:
Matthias Bisping 2022-04-14 18:24:45 +02:00
parent a60b33229b
commit dbc618aab9
5 changed files with 65 additions and 58 deletions

View File

@ -2,7 +2,6 @@ import json
import logging
import os
import random
import string
from functools import partial
from itertools import starmap
from operator import itemgetter
@ -13,7 +12,6 @@ import pytest
from funcy import rcompose, merge
from image_prediction.exceptions import (
UnknownDatabaseType,
UnknownLabelFormat,
)
from image_prediction.image_extractor.extractor import ImageMetadataPair
@ -21,16 +19,15 @@ from image_prediction.info import Info
from image_prediction.label_mapper.mappers.numeric import IndexMapper
from image_prediction.label_mapper.mappers.probability import ProbabilityMapper, ProbabilityMapperKeys
from image_prediction.locations import TEST_DATA_DIR
from image_prediction.model_loader.database.connectors.mock import DatabaseConnectorMock
from image_prediction.model_loader.loader import ModelLoader
from image_prediction.model_loader.loaders.mlflow import MlflowConnector
from image_prediction.pipeline import load_pipeline
from image_prediction.redai_adapter.mlflow import MlflowModelReader
from image_prediction.utils import get_logger
from test.utils.generation.image import array_to_image
from test.utils.generation.pdf import add_image, pdf_stream
pytest_plugins = ['test.utils.model']
pytest_plugins = [
'test.fixtures.model',
'test.fixtures.model_store',
]
@ -231,57 +228,6 @@ def pdf(image_metadata_pairs):
return pdf_stream(pdf)
@pytest.fixture
def model_database_record_identifier():
return "".join(random.sample(string.ascii_letters, k=10))
@pytest.fixture
def model_database_record(model, classes):
return {"model": model, "classes": classes}
@pytest.fixture
def model_database(model_database_record, model_database_record_identifier):
return {model_database_record_identifier: model_database_record}
@pytest.fixture
def database_connector(database_type, model_database, mlflow_reader):
if database_type == "mock":
return DatabaseConnectorMock(model_database)
elif database_type == "mlflow":
return MlflowConnector(mlflow_reader)
else:
raise UnknownDatabaseType(f"No connector for database type {database_type} was specified.")
@pytest.fixture
def model_loader(database_connector):
return ModelLoader(database_connector)
@pytest.fixture
def mlflow_run_id():
from image_prediction.config import CONFIG
return CONFIG.service.run_id
@pytest.fixture
def mlruns_dir():
from image_prediction.locations import MLRUNS_DIR
return MLRUNS_DIR
@pytest.fixture
def mlflow_reader(mlruns_dir):
return MlflowModelReader(mlruns_dir)
@pytest.fixture
def real_pdf():
with open(os.path.join(TEST_DATA_DIR, "f2dc689ca794fccb8cd38b95f2bf6ba9.pdf"), "rb") as f:

0
test/fixtures/input.py vendored Normal file
View File

61
test/fixtures/model_store.py vendored Normal file
View File

@ -0,0 +1,61 @@
import random
import string
import pytest
from image_prediction.exceptions import UnknownDatabaseType
from image_prediction.model_loader.database.connectors.mock import DatabaseConnectorMock
from image_prediction.model_loader.loader import ModelLoader
from image_prediction.model_loader.loaders.mlflow import MlflowConnector
from image_prediction.redai_adapter.mlflow import MlflowModelReader
@pytest.fixture
def model_database_record_identifier():
return "".join(random.sample(string.ascii_letters, k=10))
@pytest.fixture
def model_database_record(model, classes):
return {"model": model, "classes": classes}
@pytest.fixture
def model_database(model_database_record, model_database_record_identifier):
return {model_database_record_identifier: model_database_record}
@pytest.fixture
def database_connector(database_type, model_database, mlflow_reader):
if database_type == "mock":
return DatabaseConnectorMock(model_database)
elif database_type == "mlflow":
return MlflowConnector(mlflow_reader)
else:
raise UnknownDatabaseType(f"No connector for database type {database_type} was specified.")
@pytest.fixture
def model_loader(database_connector):
return ModelLoader(database_connector)
@pytest.fixture
def mlflow_run_id():
from image_prediction.config import CONFIG
return CONFIG.service.run_id
@pytest.fixture
def mlruns_dir():
from image_prediction.locations import MLRUNS_DIR
return MLRUNS_DIR
@pytest.fixture
def mlflow_reader(mlruns_dir):
return MlflowModelReader(mlruns_dir)