refactoring of mlflow model loader

This commit is contained in:
Matthias Bisping 2022-03-29 11:02:43 +02:00
parent 3b4c2a40b2
commit 6b58756103
6 changed files with 54 additions and 10 deletions

View File

@ -8,3 +8,7 @@ class UnknownImageExtractor(ValueError):
class UnknownModelLoader(ValueError): class UnknownModelLoader(ValueError):
pass pass
class IncorrectInstantiation(RuntimeError):
pass

View File

@ -4,9 +4,9 @@ import abc
class ModelLoader(abc.ABC): class ModelLoader(abc.ABC):
@abc.abstractmethod @abc.abstractmethod
def load_model(self, identifier): def load_model(self, *args, **kwargs):
pass pass
@abc.abstractmethod @abc.abstractmethod
def load_classes(self, identifier): def load_classes(self, *args, **kwargs):
pass pass

View File

@ -0,0 +1,19 @@
from image_prediction.exceptions import UnknownModelLoader
from image_prediction.model_loader.loaders.mlflow import MlflowLoader
def get_mlflow_loader():
from image_prediction.locations import BASE_WEIGHTS
from image_prediction.config import CONFIG
loader = MlflowLoader(CONFIG.service.run_id)
loader._base_weights = BASE_WEIGHTS
return loader
def get_model_loader(loader_type: type):
if loader_type == MlflowLoader:
return get_mlflow_loader()
else:
raise UnknownModelLoader(f"No model loader for type {loader_type} was specified.")

View File

@ -1,3 +1,10 @@
"""This module translates between the new ModelLoader API and the inconsistent and historically grown redai model and
MLflow API as well as the circumstance, that the model artifacts are currently not stored at a single place, due to the
need of loading the base weights of the pre-trained model, that became apparent at a later point than the design of the
MLflow storage and MlflowModelReader class; that is why the code in this module is so unclean. In the future, a
non-adhoc solution should be used that offers a clean API and storage solution. Either implement a well-designed MLflow
based solution or look into an alternative such as WandB or use a platform solution such as AWS.
"""
import importlib import importlib
import json import json
import os import os
@ -5,7 +12,7 @@ import warnings
import numpy as np import numpy as np
from image_prediction.locations import BASE_WEIGHTS from image_prediction.exceptions import IncorrectInstantiation
from image_prediction.model_loader.loader import ModelLoader from image_prediction.model_loader.loader import ModelLoader
warnings.filterwarnings("ignore", category=DeprecationWarning, module="pkg_resources") warnings.filterwarnings("ignore", category=DeprecationWarning, module="pkg_resources")
@ -64,12 +71,23 @@ class MlflowLoader(ModelLoader):
def __init__(self, mlruns_dir): def __init__(self, mlruns_dir):
self.__mlruns_dir = mlruns_dir self.__mlruns_dir = mlruns_dir
self._model_handle = None self._model_handle = None
self.__last_run_id = None
self._base_weights = None
def load_model(self, run_id): def load_model(self, run_id, base_weights=None):
if not self._model_handle:
if not base_weights:
if not self._base_weights:
raise IncorrectInstantiation("MlflowReader needs to be initialized via get_model_loader.")
base_weights = self._base_weights
if not self._model_handle and run_id == self.__last_run_id:
mlflow_reader = MlflowModelReader(run_id, mlruns_dir=self.__mlruns_dir) mlflow_reader = MlflowModelReader(run_id, mlruns_dir=self.__mlruns_dir)
model_handel = mlflow_reader.get_model_handle(BASE_WEIGHTS) model_handel = mlflow_reader.get_model_handle(base_weights)
self._model_handle = model_handel self._model_handle = model_handel
self.__last_run_id = run_id
return self._model_handle return self._model_handle

View File

@ -16,6 +16,7 @@ from image_prediction.exceptions import UnknownEstimatorAdapter, UnknownImageExt
from image_prediction.image_extractor.extractor import ImageMetadataPair from image_prediction.image_extractor.extractor import ImageMetadataPair
from image_prediction.image_extractor.extractors.mock import ImageExtractorMock from image_prediction.image_extractor.extractors.mock import ImageExtractorMock
from image_prediction.image_extractor.extractors.parsable import ParsablePDFImageExtractor from image_prediction.image_extractor.extractors.parsable import ParsablePDFImageExtractor
from image_prediction.model_loader.loaders.loaders import get_mlflow_loader
from image_prediction.model_loader.loaders.mlflow import MlflowLoader from image_prediction.model_loader.loaders.mlflow import MlflowLoader
from image_prediction.model_loader.loaders.mock import ModelLoaderMock from image_prediction.model_loader.loaders.mock import ModelLoaderMock
@ -225,7 +226,7 @@ def model_loader(loader_type, monkeypatch, model_handle_mock, classes):
monkeypatch.setattr(loader, "model", model_handle_mock) monkeypatch.setattr(loader, "model", model_handle_mock)
monkeypatch.setattr(loader, "classes", classes) monkeypatch.setattr(loader, "classes", classes)
elif loader_type == "mlflow": elif loader_type == "mlflow":
loader = MlflowLoader("...") loader = get_mlflow_loader()
monkeypatch.setattr(loader, "_model_handle", model_handle_mock) monkeypatch.setattr(loader, "_model_handle", model_handle_mock)
else: else:
raise UnknownModelLoader(f"No model loader for type {loader_type} was specified.") raise UnknownModelLoader(f"No model loader for type {loader_type} was specified.")

View File

@ -8,6 +8,8 @@ from image_prediction.model_loading import load_model_and_classes
@pytest.mark.parametrize("estimator_type", ["mock"]) @pytest.mark.parametrize("estimator_type", ["mock"])
@pytest.mark.parametrize("batch_size", [3]) @pytest.mark.parametrize("batch_size", [3])
def test_load_model_and_classes(model_loader, model_handle_mock, classes): def test_load_model_and_classes(model_loader, model_handle_mock, classes):
model_loaded, classes_loaded = load_model_and_classes("some random identifier", model_loader=model_loader) # Load twice to test caching logic
assert model_loaded == model_handle_mock for _ in range(2):
assert np.all(classes_loaded == classes) model_loaded, classes_loaded = load_model_and_classes("some random identifier", model_loader=model_loader)
assert model_loaded == model_handle_mock
assert np.all(classes_loaded == classes)