"""Tests for ``model_loader``: loading a model and its class labels from a mock database and from an MLflow store."""
import pytest
from image_prediction.model_loader.loaders.mlflow import PredictionModelHandle


@pytest.mark.parametrize("database_type", ["mock"])
def test_load_model_and_classes(
    model_loader, model_database_record_identifier, model, classes
):
    """Loading by record identifier returns the expected model and class list.

    Runs with ``database_type`` set to ``"mock"``. ``database_type`` is not a
    parameter of this test, so presumably a fixture defined elsewhere (e.g. the
    one building ``model_loader``) consumes it -- TODO confirm in conftest.

    Args:
        model_loader: Fixture providing ``load_model`` / ``load_classes``.
        model_database_record_identifier: Key passed to both loader calls.
        model: Fixture holding the model expected back from ``load_model``.
        classes: Fixture holding the classes expected back from ``load_classes``.
    """
    model_loaded = model_loader.load_model(model_database_record_identifier)
    classes_loaded = model_loader.load_classes(model_database_record_identifier)

    assert model_loaded == model
    assert classes_loaded == classes


@pytest.mark.parametrize("database_type", ["mlflow"])
def test_load_model_and_classes_from_mlflow_store(model_loader, mlflow_run_id):
    """Loading by MLflow run id yields a ``PredictionModelHandle`` and fixed labels.

    Runs with ``database_type`` set to ``"mlflow"``. As in the mock-database
    test, ``database_type`` is presumably consumed by a fixture defined
    elsewhere -- TODO confirm in conftest.

    Args:
        model_loader: Fixture providing ``load_model`` / ``load_classes``.
        mlflow_run_id: Identifier passed to both loader calls.
    """
    model_loaded = model_loader.load_model(mlflow_run_id)
    classes_loaded = model_loader.load_classes(mlflow_run_id)

    # isinstance is the idiomatic type check (flake8 E721 flags type(x) == T).
    assert isinstance(model_loaded, PredictionModelHandle)
    assert classes_loaded == ["formula", "logo", "other", "signature"]