refactoring of mlflow model loader
This commit is contained in:
parent
3b4c2a40b2
commit
6b58756103
@ -8,3 +8,7 @@ class UnknownImageExtractor(ValueError):
|
|||||||
|
|
||||||
class UnknownModelLoader(ValueError):
    """Raised when a model loader is requested for a type that has no loader."""
|
||||||
|
|
||||||
|
|
||||||
|
class IncorrectInstantiation(RuntimeError):
    """Raised when an object is used without the setup it needs to operate."""
|
||||||
|
|||||||
@ -4,9 +4,9 @@ import abc
|
|||||||
class ModelLoader(abc.ABC):
    """Abstract interface for model loaders.

    Concrete loaders must implement both ``load_model`` and ``load_classes``;
    the arguments each method accepts are left to the implementation.
    """

    @abc.abstractmethod
    def load_model(self, *args, **kwargs):
        """Load a model; the accepted arguments are defined by the subclass."""

    @abc.abstractmethod
    def load_classes(self, *args, **kwargs):
        """Load the classes; the accepted arguments are defined by the subclass."""
|
||||||
|
|||||||
19
image_prediction/model_loader/loaders/loaders.py
Normal file
19
image_prediction/model_loader/loaders/loaders.py
Normal file
@ -0,0 +1,19 @@
|
|||||||
|
from image_prediction.exceptions import UnknownModelLoader
|
||||||
|
from image_prediction.model_loader.loaders.mlflow import MlflowLoader
|
||||||
|
|
||||||
|
|
||||||
|
def get_mlflow_loader():
    """Build an ``MlflowLoader`` wired up from the project configuration.

    The loader is constructed with ``CONFIG.service.run_id`` and then has its
    ``_base_weights`` attribute set to ``BASE_WEIGHTS``.

    Returns:
        The configured ``MlflowLoader`` instance.
    """
    # Both names are imported at call time rather than at module import time.
    from image_prediction.locations import BASE_WEIGHTS
    from image_prediction.config import CONFIG

    # NOTE(review): MlflowLoader.__init__ names its only parameter `mlruns_dir`,
    # yet a run id is passed here -- confirm this is intended.
    mlflow_loader = MlflowLoader(CONFIG.service.run_id)

    # Without base weights the loader's load_model raises IncorrectInstantiation.
    mlflow_loader._base_weights = BASE_WEIGHTS
    return mlflow_loader
|
||||||
|
|
||||||
|
|
||||||
|
def get_model_loader(loader_type: type):
    """Return a ready-to-use loader instance for the given loader class.

    Args:
        loader_type: The loader class to build; only ``MlflowLoader`` is handled.

    Returns:
        The loader produced by ``get_mlflow_loader``.

    Raises:
        UnknownModelLoader: If ``loader_type`` is anything but ``MlflowLoader``.
    """
    # Guard clause: reject every type that has no factory.
    if loader_type != MlflowLoader:
        raise UnknownModelLoader(f"No model loader for type {loader_type} was specified.")

    return get_mlflow_loader()
|
||||||
@ -1,3 +1,10 @@
|
|||||||
|
"""This module translates between the new ModelLoader API and the inconsistent and historically grown redai model and
|
||||||
|
MLflow APIs. It also works around the fact that the model artifacts are currently not stored in a single place: the
|
||||||
|
need to load the base weights of the pre-trained model only became apparent after the MLflow storage and the
|
||||||
|
MlflowModelReader class had been designed, which is why the code in this module is so unclean. In the future, a
|
||||||
|
non-ad-hoc solution that offers a clean API and storage layout should be used: either implement a well-designed
|
||||||
|
MLflow-based solution, look into an alternative such as WandB, or use a platform solution such as AWS.
|
||||||
|
"""
|
||||||
import importlib
|
import importlib
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
@ -5,7 +12,7 @@ import warnings
|
|||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from image_prediction.locations import BASE_WEIGHTS
|
from image_prediction.exceptions import IncorrectInstantiation
|
||||||
from image_prediction.model_loader.loader import ModelLoader
|
from image_prediction.model_loader.loader import ModelLoader
|
||||||
|
|
||||||
warnings.filterwarnings("ignore", category=DeprecationWarning, module="pkg_resources")
|
warnings.filterwarnings("ignore", category=DeprecationWarning, module="pkg_resources")
|
||||||
@ -64,12 +71,23 @@ class MlflowLoader(ModelLoader):
|
|||||||
def __init__(self, mlruns_dir):
    """Create a loader with an empty model cache.

    Args:
        mlruns_dir: Stored and later forwarded as ``mlruns_dir`` to
            ``MlflowModelReader`` when a model is loaded.
    """
    # Cache state: no handle loaded yet, hence no run id recorded either.
    self._model_handle = None
    self.__last_run_id = None

    # Assigned from outside after construction (get_mlflow_loader does this);
    # load_model raises while it is unset and no weights are passed in.
    self._base_weights = None

    self.__mlruns_dir = mlruns_dir
|
||||||
|
|
||||||
def load_model(self, run_id, base_weights=None):
    """Return the model handle for *run_id*, loading it when necessary.

    The handle is cached on the instance together with the run id it was
    loaded for. A load happens when nothing is cached yet, or when the cached
    handle was loaded for a different run id.

    Args:
        run_id: Identifier handed to ``MlflowModelReader`` to locate the model.
        base_weights: Passed to ``MlflowModelReader.get_model_handle``. When
            falsy, the instance's ``_base_weights`` attribute is used instead.

    Returns:
        The cached model handle.

    Raises:
        IncorrectInstantiation: If neither *base_weights* nor
            ``self._base_weights`` is set.
    """
    if not base_weights:
        if not self._base_weights:
            # NOTE(review): the message says "MlflowReader"; the text is kept
            # verbatim in case callers match on it.
            raise IncorrectInstantiation("MlflowReader needs to be initialized via get_model_loader.")
        base_weights = self._base_weights

    # The previous condition (`not handle and run_id == last_run_id`) could
    # never be true on a fresh instance, so no model was ever loaded.
    nothing_cached = self._model_handle is None
    # A handle with no recorded run id (e.g. injected from outside) is kept.
    other_run_cached = self.__last_run_id is not None and run_id != self.__last_run_id

    if nothing_cached or other_run_cached:
        reader = MlflowModelReader(run_id, mlruns_dir=self.__mlruns_dir)
        self._model_handle = reader.get_model_handle(base_weights)
        self.__last_run_id = run_id

    return self._model_handle
|
||||||
|
|
||||||
|
|||||||
@ -16,6 +16,7 @@ from image_prediction.exceptions import UnknownEstimatorAdapter, UnknownImageExt
|
|||||||
from image_prediction.image_extractor.extractor import ImageMetadataPair
|
from image_prediction.image_extractor.extractor import ImageMetadataPair
|
||||||
from image_prediction.image_extractor.extractors.mock import ImageExtractorMock
|
from image_prediction.image_extractor.extractors.mock import ImageExtractorMock
|
||||||
from image_prediction.image_extractor.extractors.parsable import ParsablePDFImageExtractor
|
from image_prediction.image_extractor.extractors.parsable import ParsablePDFImageExtractor
|
||||||
|
from image_prediction.model_loader.loaders.loaders import get_mlflow_loader
|
||||||
from image_prediction.model_loader.loaders.mlflow import MlflowLoader
|
from image_prediction.model_loader.loaders.mlflow import MlflowLoader
|
||||||
from image_prediction.model_loader.loaders.mock import ModelLoaderMock
|
from image_prediction.model_loader.loaders.mock import ModelLoaderMock
|
||||||
|
|
||||||
@ -225,7 +226,7 @@ def model_loader(loader_type, monkeypatch, model_handle_mock, classes):
|
|||||||
monkeypatch.setattr(loader, "model", model_handle_mock)
|
monkeypatch.setattr(loader, "model", model_handle_mock)
|
||||||
monkeypatch.setattr(loader, "classes", classes)
|
monkeypatch.setattr(loader, "classes", classes)
|
||||||
elif loader_type == "mlflow":
|
elif loader_type == "mlflow":
|
||||||
loader = MlflowLoader("...")
|
loader = get_mlflow_loader()
|
||||||
monkeypatch.setattr(loader, "_model_handle", model_handle_mock)
|
monkeypatch.setattr(loader, "_model_handle", model_handle_mock)
|
||||||
else:
|
else:
|
||||||
raise UnknownModelLoader(f"No model loader for type {loader_type} was specified.")
|
raise UnknownModelLoader(f"No model loader for type {loader_type} was specified.")
|
||||||
|
|||||||
@ -8,6 +8,8 @@ from image_prediction.model_loading import load_model_and_classes
|
|||||||
@pytest.mark.parametrize("estimator_type", ["mock"])
@pytest.mark.parametrize("batch_size", [3])
def test_load_model_and_classes(model_loader, model_handle_mock, classes):
    """Loading returns the mocked handle and the expected classes on every call."""
    attempts = 0
    # Two rounds, so the second call exercises the loader's caching logic.
    while attempts < 2:
        loaded = load_model_and_classes("some random identifier", model_loader=model_loader)
        model_loaded, classes_loaded = loaded
        assert model_loaded == model_handle_mock
        assert np.all(classes_loaded == classes)
        attempts += 1
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user