22 lines
857 B
Python

import pytest
from image_prediction.model_loader.loaders.mlflow import PredictionModelHandle
@pytest.mark.parametrize("database_type", ["mock"])
def test_load_model_and_classes(model_loader, model_database_record_identifier, model, classes):
    """Check that the loader returns the expected model and classes.

    Both artifacts are fetched through ``model_loader`` using the same
    record identifier and compared for equality against the ``model`` and
    ``classes`` arguments.

    NOTE(review): ``database_type`` is not a direct argument of this test;
    presumably it parametrizes a fixture that ``model_loader`` depends on --
    confirm against the fixture definitions.
    """
    record_id = model_database_record_identifier

    # Fetch both artifacts first, then verify them.
    fetched_model = model_loader.load_model(record_id)
    fetched_classes = model_loader.load_classes(record_id)

    assert fetched_model == model
    assert fetched_classes == classes
@pytest.mark.parametrize("database_type", ["mlflow"])
def test_load_model_and_classes_from_mlflow_store(model_loader, mlflow_run_id):
    """Check loading a model and its classes by MLflow run id.

    The loaded model must be a ``PredictionModelHandle`` and the loaded
    classes must equal the expected list of four labels.

    NOTE(review): ``database_type`` is not a direct argument of this test;
    presumably it parametrizes a fixture that ``model_loader`` depends on --
    confirm against the fixture definitions.
    """
    model_loaded = model_loader.load_model(mlflow_run_id)
    classes_loaded = model_loader.load_classes(mlflow_run_id)

    # isinstance() is the idiomatic type check (avoids the `type(x) == T`
    # comparison flagged by linters) and also accepts subclasses of the handle.
    assert isinstance(model_loaded, PredictionModelHandle)
    assert classes_loaded == ["formula", "logo", "other", "signature"]