From 6b58756103effebf8466e71e50334f0b3a3c9804 Mon Sep 17 00:00:00 2001 From: Matthias Bisping Date: Tue, 29 Mar 2022 11:02:43 +0200 Subject: [PATCH] refactoring of mlflow model loader --- image_prediction/exceptions.py | 4 +++ image_prediction/model_loader/loader.py | 4 +-- .../model_loader/loaders/loaders.py | 19 ++++++++++++++ .../model_loader/loaders/mlflow.py | 26 ++++++++++++++++--- test/unit_tests/conftest.py | 3 ++- test/unit_tests/model_loader_test.py | 8 +++--- 6 files changed, 54 insertions(+), 10 deletions(-) create mode 100644 image_prediction/model_loader/loaders/loaders.py diff --git a/image_prediction/exceptions.py b/image_prediction/exceptions.py index 6982dff..fe0136c 100644 --- a/image_prediction/exceptions.py +++ b/image_prediction/exceptions.py @@ -8,3 +8,7 @@ class UnknownImageExtractor(ValueError): class UnknownModelLoader(ValueError): pass + + +class IncorrectInstantiation(RuntimeError): + pass diff --git a/image_prediction/model_loader/loader.py b/image_prediction/model_loader/loader.py index b78543e..32e76d6 100644 --- a/image_prediction/model_loader/loader.py +++ b/image_prediction/model_loader/loader.py @@ -4,9 +4,9 @@ import abc class ModelLoader(abc.ABC): @abc.abstractmethod - def load_model(self, identifier): + def load_model(self, *args, **kwargs): pass @abc.abstractmethod - def load_classes(self, identifier): + def load_classes(self, *args, **kwargs): pass diff --git a/image_prediction/model_loader/loaders/loaders.py b/image_prediction/model_loader/loaders/loaders.py new file mode 100644 index 0000000..ef9fee8 --- /dev/null +++ b/image_prediction/model_loader/loaders/loaders.py @@ -0,0 +1,19 @@ +from image_prediction.exceptions import UnknownModelLoader +from image_prediction.model_loader.loaders.mlflow import MlflowLoader + + +def get_mlflow_loader(): + from image_prediction.locations import BASE_WEIGHTS + from image_prediction.config import CONFIG + + loader = MlflowLoader(CONFIG.service.run_id) + loader._base_weights = BASE_WEIGHTS + + return loader + + +def get_model_loader(loader_type: type): + if loader_type == MlflowLoader: + return get_mlflow_loader() + else: + raise UnknownModelLoader(f"No model loader for type {loader_type} was specified.") diff --git a/image_prediction/model_loader/loaders/mlflow.py b/image_prediction/model_loader/loaders/mlflow.py index dcbce5e..3328b5d 100644 --- a/image_prediction/model_loader/loaders/mlflow.py +++ b/image_prediction/model_loader/loaders/mlflow.py @@ -1,3 +1,10 @@ +"""This module translates between the new ModelLoader API and the inconsistent and historically grown redai model and +MLflow API as well as the circumstance, that the model artifacts are currently not stored at a single place, due to the +need of loading the base weights of the pre-trained model, that became apparent at a later point than the design of the +MLflow storage and MlflowModelReader class; that is why the code in this module is so unclean. In the future, a +non-adhoc solution should be used that offers a clean API and storage solution. Either implement a well-designed MLflow +based solution or look into an alternative such as WandB or use a platform solution such as AWS. +""" import importlib import json import os @@ -5,7 +12,7 @@ import warnings import numpy as np -from image_prediction.locations import BASE_WEIGHTS +from image_prediction.exceptions import IncorrectInstantiation from image_prediction.model_loader.loader import ModelLoader warnings.filterwarnings("ignore", category=DeprecationWarning, module="pkg_resources") @@ -64,12 +71,23 @@ class MlflowLoader(ModelLoader): def __init__(self, mlruns_dir): self.__mlruns_dir = mlruns_dir self._model_handle = None + self.__last_run_id = None + self._base_weights = None - def load_model(self, run_id): - if not self._model_handle: + def load_model(self, run_id, base_weights=None): + + if not base_weights: + + if not self._base_weights: + raise IncorrectInstantiation("MlflowReader needs to be initialized via get_model_loader.") + + base_weights = self._base_weights + + if not self._model_handle and run_id == self.__last_run_id: mlflow_reader = MlflowModelReader(run_id, mlruns_dir=self.__mlruns_dir) - model_handel = mlflow_reader.get_model_handle(BASE_WEIGHTS) + model_handel = mlflow_reader.get_model_handle(base_weights) self._model_handle = model_handel + self.__last_run_id = run_id return self._model_handle diff --git a/test/unit_tests/conftest.py b/test/unit_tests/conftest.py index 4b281d8..6506c1d 100644 --- a/test/unit_tests/conftest.py +++ b/test/unit_tests/conftest.py @@ -16,6 +16,7 @@ from image_prediction.exceptions import UnknownEstimatorAdapter, UnknownImageExt from image_prediction.image_extractor.extractor import ImageMetadataPair from image_prediction.image_extractor.extractors.mock import ImageExtractorMock from image_prediction.image_extractor.extractors.parsable import ParsablePDFImageExtractor +from image_prediction.model_loader.loaders.loaders import get_mlflow_loader from image_prediction.model_loader.loaders.mlflow import MlflowLoader from image_prediction.model_loader.loaders.mock import ModelLoaderMock @@ -225,7 +226,7 @@ def model_loader(loader_type, monkeypatch, model_handle_mock, classes): monkeypatch.setattr(loader, "model", model_handle_mock) monkeypatch.setattr(loader, "classes", classes) elif loader_type == "mlflow": - loader = MlflowLoader("...") + loader = get_mlflow_loader() monkeypatch.setattr(loader, "_model_handle", model_handle_mock) else: raise UnknownModelLoader(f"No model loader for type {loader_type} was specified.") diff --git a/test/unit_tests/model_loader_test.py b/test/unit_tests/model_loader_test.py index 3139c7a..8b4a64a 100644 --- a/test/unit_tests/model_loader_test.py +++ b/test/unit_tests/model_loader_test.py @@ -8,6 +8,8 @@ from image_prediction.model_loading import load_model_and_classes @pytest.mark.parametrize("estimator_type", ["mock"]) @pytest.mark.parametrize("batch_size", [3]) def test_load_model_and_classes(model_loader, model_handle_mock, classes): - model_loaded, classes_loaded = load_model_and_classes("some random identifier", model_loader=model_loader) - assert model_loaded == model_handle_mock - assert np.all(classes_loaded == classes) + # Load twice to test caching logic + for _ in range(2): + model_loaded, classes_loaded = load_model_and_classes("some random identifier", model_loader=model_loader) + assert model_loaded == model_handle_mock + assert np.all(classes_loaded == classes)