2022-03-29 11:02:43 +02:00

20 lines
583 B
Python

from image_prediction.exceptions import UnknownModelLoader
from image_prediction.model_loader.loaders.mlflow import MlflowLoader
def get_mlflow_loader():
    """Build an ``MlflowLoader`` for the MLflow run named in the service config.

    The loader is constructed with ``CONFIG.service.run_id``. Its
    ``_base_weights`` attribute is then set to ``BASE_WEIGHTS`` from
    ``image_prediction.locations``.

    Returns:
        The configured ``MlflowLoader`` instance.
    """
    # Imports are deferred to call time. Presumably this avoids import-time
    # side effects or circular imports between config/locations and this
    # module (NOTE(review): confirm). The original order is kept on purpose.
    from image_prediction.locations import BASE_WEIGHTS
    from image_prediction.config import CONFIG

    run_id = CONFIG.service.run_id
    mlflow_loader = MlflowLoader(run_id)
    # NOTE(review): this sets a private attribute from outside the class,
    # bypassing the constructor. Confirm MlflowLoader reads `_base_weights`,
    # or consider exposing it as a constructor argument.
    mlflow_loader._base_weights = BASE_WEIGHTS
    return mlflow_loader
def get_model_loader(loader_type: type):
    """Return a model loader instance for the requested loader class.

    Only ``MlflowLoader`` is supported. ``loader_type`` is compared with
    ``==``, so a subclass of ``MlflowLoader`` does not match.

    Args:
        loader_type: The loader class to instantiate.

    Returns:
        The loader built by ``get_mlflow_loader()``.

    Raises:
        UnknownModelLoader: If ``loader_type`` is not ``MlflowLoader``.
    """
    if loader_type != MlflowLoader:
        raise UnknownModelLoader(f"No model loader for type {loader_type} was specified.")
    return get_mlflow_loader()