"""Tests for loading models and class labels through the ``model_loader`` fixture."""
import pytest
from image_prediction.redai_adapter.model import PredictionModelHandle


@pytest.mark.parametrize("database_type", ["mock"])
def test_load_model_and_classes(model_loader, model_database_record_identifier, model, classes):
    """Load a model and its classes by record identifier and compare to fixtures.

    Runs with ``database_type`` parametrized to ``"mock"``. The parameter is
    not a direct argument here, so it is presumably consumed by one of the
    requested fixtures (e.g. ``model_loader``); confirm in conftest.

    The same record identifier is used for both lookups. The loaded model must
    equal the ``model`` fixture, and the loaded classes must equal the
    ``classes`` fixture.
    """
    record_id = model_database_record_identifier
    # Tuple elements are evaluated left to right: model first, then classes.
    loaded_model, loaded_classes = (
        model_loader.load_model(record_id),
        model_loader.load_classes(record_id),
    )

    assert loaded_model == model
    assert loaded_classes == classes


@pytest.mark.parametrize("database_type", ["mlflow"])
def test_load_model_and_classes_from_mlflow_store(model_loader, mlflow_run_id):
    """Load a model and its class labels from the MLflow-backed store by run id.

    Runs with ``database_type`` parametrized to ``"mlflow"``. The parameter is
    not a direct argument here, so it is presumably consumed by one of the
    requested fixtures (e.g. ``model_loader``); confirm in conftest.

    Checks that the loaded model is a ``PredictionModelHandle``. Checks that
    the loaded class labels are exactly ``["formula", "logo", "other",
    "signature"]``, in that order.
    """
    model_loaded = model_loader.load_model(mlflow_run_id)
    classes_loaded = model_loader.load_classes(mlflow_run_id)

    # isinstance instead of ``type(...) ==`` (E721): idiomatic type check that
    # also accepts subclasses of the handle type.
    assert isinstance(model_loaded, PredictionModelHandle)
    assert classes_loaded == ["formula", "logo", "other", "signature"]