diff --git a/image_prediction/locations.py b/image_prediction/locations.py index 4bf8d4b..9dfe2f4 100644 --- a/image_prediction/locations.py +++ b/image_prediction/locations.py @@ -7,4 +7,3 @@ CONFIG_FILE = path.join(PACKAGE_ROOT_DIR, "config.yaml") DATA_DIR = path.join(PACKAGE_ROOT_DIR, "data") MLRUNS_DIR = path.join(DATA_DIR, "mlruns") -BASE_WEIGHTS = path.join(DATA_DIR, "base_weights.h5") diff --git a/image_prediction/model_loader/loaders/loaders.py b/image_prediction/model_loader/loaders/loaders.py deleted file mode 100644 index ef9fee8..0000000 --- a/image_prediction/model_loader/loaders/loaders.py +++ /dev/null @@ -1,19 +0,0 @@ -from image_prediction.exceptions import UnknownModelLoader -from image_prediction.model_loader.loaders.mlflow import MlflowLoader - - -def get_mlflow_loader(): - from image_prediction.locations import BASE_WEIGHTS - from image_prediction.config import CONFIG - - loader = MlflowLoader(CONFIG.service.run_id) - loader._base_weights = BASE_WEIGHTS - - return loader - - -def get_model_loader(loader_type: type): - if loader_type == MlflowLoader: - return get_mlflow_loader() - else: - raise UnknownModelLoader(f"No model loader for type {loader_type} was specified.")