refactoring of mlflow model loader
This commit is contained in:
parent
3b4c2a40b2
commit
6b58756103
@ -8,3 +8,7 @@ class UnknownImageExtractor(ValueError):
|
||||
|
||||
class UnknownModelLoader(ValueError):
|
||||
pass
|
||||
|
||||
|
||||
class IncorrectInstantiation(RuntimeError):
|
||||
pass
|
||||
|
||||
@ -4,9 +4,9 @@ import abc
|
||||
class ModelLoader(abc.ABC):
|
||||
|
||||
@abc.abstractmethod
|
||||
def load_model(self, identifier):
|
||||
def load_model(self, *args, **kwargs):
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def load_classes(self, identifier):
|
||||
def load_classes(self, *args, **kwargs):
|
||||
pass
|
||||
|
||||
19
image_prediction/model_loader/loaders/loaders.py
Normal file
19
image_prediction/model_loader/loaders/loaders.py
Normal file
@ -0,0 +1,19 @@
|
||||
from image_prediction.exceptions import UnknownModelLoader
|
||||
from image_prediction.model_loader.loaders.mlflow import MlflowLoader
|
||||
|
||||
|
||||
def get_mlflow_loader():
|
||||
from image_prediction.locations import BASE_WEIGHTS
|
||||
from image_prediction.config import CONFIG
|
||||
|
||||
loader = MlflowLoader(CONFIG.service.run_id)
|
||||
loader._base_weights = BASE_WEIGHTS
|
||||
|
||||
return loader
|
||||
|
||||
|
||||
def get_model_loader(loader_type: type):
|
||||
if loader_type == MlflowLoader:
|
||||
return get_mlflow_loader()
|
||||
else:
|
||||
raise UnknownModelLoader(f"No model loader for type {loader_type} was specified.")
|
||||
@ -1,3 +1,10 @@
|
||||
"""This module translates between the new ModelLoader API and the inconsistent and historically grown redai model and
|
||||
MLflow API as well as the circumstance, that the model artifacts are currently not stored at a single place, due to the
|
||||
need of loading the base weights of the pre-trained model, that became apparent at a later point than the design of the
|
||||
MLflow storage and MlflowModelReader class; that is why the code in this module is so unclean. In the future, a
|
||||
non-adhoc solution should be used that offers a clean API and storage solution. Either implement a well-designed MLflow
|
||||
based solution or look into an alternative such as WandB or use a platform solution such as AWS.
|
||||
"""
|
||||
import importlib
|
||||
import json
|
||||
import os
|
||||
@ -5,7 +12,7 @@ import warnings
|
||||
|
||||
import numpy as np
|
||||
|
||||
from image_prediction.locations import BASE_WEIGHTS
|
||||
from image_prediction.exceptions import IncorrectInstantiation
|
||||
from image_prediction.model_loader.loader import ModelLoader
|
||||
|
||||
warnings.filterwarnings("ignore", category=DeprecationWarning, module="pkg_resources")
|
||||
@ -64,12 +71,23 @@ class MlflowLoader(ModelLoader):
|
||||
def __init__(self, mlruns_dir):
|
||||
self.__mlruns_dir = mlruns_dir
|
||||
self._model_handle = None
|
||||
self.__last_run_id = None
|
||||
self._base_weights = None
|
||||
|
||||
def load_model(self, run_id):
|
||||
if not self._model_handle:
|
||||
def load_model(self, run_id, base_weights=None):
|
||||
|
||||
if not base_weights:
|
||||
|
||||
if not self._base_weights:
|
||||
raise IncorrectInstantiation("MlflowReader needs to be initialized via get_model_loader.")
|
||||
|
||||
base_weights = self._base_weights
|
||||
|
||||
if not self._model_handle and run_id == self.__last_run_id:
|
||||
mlflow_reader = MlflowModelReader(run_id, mlruns_dir=self.__mlruns_dir)
|
||||
model_handel = mlflow_reader.get_model_handle(BASE_WEIGHTS)
|
||||
model_handel = mlflow_reader.get_model_handle(base_weights)
|
||||
self._model_handle = model_handel
|
||||
self.__last_run_id = run_id
|
||||
|
||||
return self._model_handle
|
||||
|
||||
|
||||
@ -16,6 +16,7 @@ from image_prediction.exceptions import UnknownEstimatorAdapter, UnknownImageExt
|
||||
from image_prediction.image_extractor.extractor import ImageMetadataPair
|
||||
from image_prediction.image_extractor.extractors.mock import ImageExtractorMock
|
||||
from image_prediction.image_extractor.extractors.parsable import ParsablePDFImageExtractor
|
||||
from image_prediction.model_loader.loaders.loaders import get_mlflow_loader
|
||||
from image_prediction.model_loader.loaders.mlflow import MlflowLoader
|
||||
from image_prediction.model_loader.loaders.mock import ModelLoaderMock
|
||||
|
||||
@ -225,7 +226,7 @@ def model_loader(loader_type, monkeypatch, model_handle_mock, classes):
|
||||
monkeypatch.setattr(loader, "model", model_handle_mock)
|
||||
monkeypatch.setattr(loader, "classes", classes)
|
||||
elif loader_type == "mlflow":
|
||||
loader = MlflowLoader("...")
|
||||
loader = get_mlflow_loader()
|
||||
monkeypatch.setattr(loader, "_model_handle", model_handle_mock)
|
||||
else:
|
||||
raise UnknownModelLoader(f"No model loader for type {loader_type} was specified.")
|
||||
|
||||
@ -8,6 +8,8 @@ from image_prediction.model_loading import load_model_and_classes
|
||||
@pytest.mark.parametrize("estimator_type", ["mock"])
|
||||
@pytest.mark.parametrize("batch_size", [3])
|
||||
def test_load_model_and_classes(model_loader, model_handle_mock, classes):
|
||||
model_loaded, classes_loaded = load_model_and_classes("some random identifier", model_loader=model_loader)
|
||||
assert model_loaded == model_handle_mock
|
||||
assert np.all(classes_loaded == classes)
|
||||
# Load twice to test caching logic
|
||||
for _ in range(2):
|
||||
model_loaded, classes_loaded = load_model_and_classes("some random identifier", model_loader=model_loader)
|
||||
assert model_loaded == model_handle_mock
|
||||
assert np.all(classes_loaded == classes)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user