20 lines
583 B
Python
20 lines
583 B
Python
from image_prediction.exceptions import UnknownModelLoader
|
|
from image_prediction.model_loader.loaders.mlflow import MlflowLoader
|
|
|
|
|
|
def get_mlflow_loader():
    """Build an ``MlflowLoader`` for the configured run and attach the base weights.

    The loader is constructed from ``CONFIG.service.run_id`` and its
    ``_base_weights`` attribute is set to ``BASE_WEIGHTS`` before it is returned.

    Returns:
        The configured ``MlflowLoader`` instance.
    """
    # NOTE(review): these imports are function-scoped in the original, presumably
    # to defer loading locations/config until a loader is requested - confirm
    # before hoisting them to module level.
    from image_prediction.locations import BASE_WEIGHTS
    from image_prediction.config import CONFIG

    mlflow_loader = MlflowLoader(CONFIG.service.run_id)
    # The base weights are set on the instance after construction rather than
    # passed to the constructor.
    mlflow_loader._base_weights = BASE_WEIGHTS
    return mlflow_loader
|
|
|
|
|
|
def get_model_loader(loader_type: type):
    """Return a model loader instance for the requested loader class.

    Args:
        loader_type: The loader class to instantiate. Only ``MlflowLoader``
            is handled here.

    Returns:
        The loader produced by ``get_mlflow_loader()`` when ``loader_type``
        equals ``MlflowLoader``.

    Raises:
        UnknownModelLoader: If ``loader_type`` is anything other than
            ``MlflowLoader``.
    """
    if loader_type == MlflowLoader:
        return get_mlflow_loader()

    # Any other type falls through to here: no factory is registered for it.
    raise UnknownModelLoader(f"No model loader for type {loader_type} was specified.")
|