redoing model loading design
This commit is contained in:
parent
a1c7dd4a8d
commit
f60bafd007
@ -10,5 +10,9 @@ class UnknownModelLoader(ValueError):
|
||||
pass
|
||||
|
||||
|
||||
class UnknownDatabaseType(ValueError):
    """Signals that no database connector exists for the requested database type."""
|
||||
|
||||
|
||||
class IncorrectInstantiation(RuntimeError):
    """Signals that an object is used without having been set up the way it requires."""
|
||||
|
||||
10
image_prediction/model_loader/database/connectors/mock.py
Normal file
10
image_prediction/model_loader/database/connectors/mock.py
Normal file
@ -0,0 +1,10 @@
|
||||
from image_prediction.model_loader.database.connector import DatabaseConnector
|
||||
|
||||
|
||||
class DatabaseConnectorMock(DatabaseConnector):
    """Database connector for tests that serves objects from an in-memory mapping."""

    def __init__(self, store: dict):
        # Mapping from record identifier to the object stored under it.
        self.store = store

    def get_object(self, identifier):
        """Return the object stored under ``identifier``.

        Raises:
            KeyError: if the backing mapping has no entry for ``identifier``.
        """
        record = self.store[identifier]
        return record
|
||||
@ -1,12 +1,19 @@
|
||||
import abc
|
||||
from functools import lru_cache
|
||||
|
||||
from image_prediction.model_loader.database.connector import DatabaseConnector
|
||||
|
||||
|
||||
class ModelLoader:
    """Loads models and their class labels from records served by a database connector.

    A record is expected to be a mapping with the keys ``"model"`` and ``"classes"``. Each record is
    fetched from the connector at most once per loader instance and then served from a per-instance
    cache. The cache lives on the instance (rather than in a ``functools.lru_cache`` wrapped around the
    method) so that it is released together with the loader and does not keep loaders alive.

    Args:
        database_connector: Object offering ``get_object(identifier)``.
    """

    def __init__(self, database_connector: "DatabaseConnector"):
        self.database_connector = database_connector
        # Records fetched so far, keyed by identifier.
        self.__cache = {}

    def __get_object(self, identifier):
        """Return the record for ``identifier``, querying the connector only on the first request.

        A lookup that raises is not cached, so it is retried on the next call.
        """
        try:
            return self.__cache[identifier]
        except KeyError:
            record = self.database_connector.get_object(identifier)
            self.__cache[identifier] = record
            return record

    def load_model(self, identifier):
        """Return the ``"model"`` entry of the record stored under ``identifier``."""
        return self.__get_object(identifier)["model"]

    def load_classes(self, identifier):
        """Return the ``"classes"`` entry of the record stored under ``identifier``."""
        return self.__get_object(identifier)["classes"]
|
||||
|
||||
@ -1,101 +1,123 @@
|
||||
"""This module translates between the new ModelLoader API and the inconsistent and historically grown redai model and
|
||||
MLflow API as well as the circumstance, that the model artifacts are currently not stored at a single place, due to the
|
||||
need of loading the base weights of the pre-trained model, that became apparent at a later point than the design of the
|
||||
MLflow storage and MlflowModelReader class; that is why the code in this module is so unclean. In the future, a
|
||||
non-adhoc solution should be used that offers a clean API and storage solution. Either implement a well-designed MLflow
|
||||
based solution or look into an alternative such as WandB or use a platform solution such as AWS.
|
||||
"""
|
||||
import importlib
|
||||
import json
|
||||
import os
|
||||
import warnings
|
||||
|
||||
import numpy as np
|
||||
|
||||
from image_prediction.exceptions import IncorrectInstantiation
|
||||
from image_prediction.model_loader.loader import ModelLoader
|
||||
|
||||
warnings.filterwarnings("ignore", category=DeprecationWarning, module="pkg_resources")
|
||||
|
||||
import mlflow
|
||||
|
||||
|
||||
def load_object(object_path):
    """Import and return the object named by a dotted path.

    Args:
        object_path: Fully qualified name of the form ``"package.module.attribute"``. Everything before
            the last dot is imported as a module; the last fragment is looked up on that module.

    Returns:
        The attribute of the imported module named by the last fragment.

    Raises:
        ValueError: if ``object_path`` has no module part, i.e. nothing precedes the last fragment.
        ImportError: if the module cannot be imported.
        AttributeError: if the module has no attribute with the requested name.
    """
    module_path, _, object_name = object_path.rpartition(".")

    # Without this check importlib fails with the unhelpful message "Empty module name".
    if not module_path:
        raise ValueError(f"Expected a dotted path of the form 'module.object', got {object_path!r}.")

    module = importlib.import_module(module_path)
    return getattr(module, object_name)
|
||||
|
||||
|
||||
def to_local_path(uri):
    """Return ``uri`` without its first seven characters.

    NOTE(review): seven is the length of ``"file://"``, so this presumably strips that scheme prefix.
    The prefix is not checked; any seven leading characters are dropped.
    """
    prefix_length = len("file://")
    return uri[prefix_length:]
|
||||
|
||||
|
||||
class MlflowModelReader:
    """Reads the weights location, class names and model handle belonging to one MLflow run.

    Args:
        run_id: Identifier of the MLflow run to read from.
        mlruns_dir: Directory that is set as MLflow tracking URI and onto which the run's recorded
            artifact location is re-rooted. NOTE(review): the default ``None`` is passed on to
            ``os.path.join`` as base path, which rejects it; callers presumably always pass a directory.
    """

    def __init__(self, run_id, mlruns_dir=None):
        mlflow.set_tracking_uri(mlruns_dir)

        self.run_id = run_id
        self.run = mlflow.get_run(run_id)
        self.artifact_uri = self.__correct_artifact_uri(self.run.info.to_proto().artifact_uri, mlruns_dir)

    @staticmethod
    def __correct_artifact_uri(run_artifact_uri, base_path):
        """Re-root the artifact location recorded in a run onto ``base_path``.

        Everything after the last ``"mlruns/"`` in ``run_artifact_uri`` is appended to ``base_path``.
        Splitting at the last occurrence keeps this working when a parent directory of the recorded
        location is itself called ``mlruns``.

        Raises:
            ValueError: if ``run_artifact_uri`` does not contain ``"mlruns/"``.
        """
        _, marker, suffix = run_artifact_uri.rpartition("mlruns/")
        if not marker:
            raise ValueError(f"Artifact URI {run_artifact_uri!r} does not contain 'mlruns/'.")
        return os.path.join(base_path, suffix)

    def get_weights_path(self, prefix="tt"):
        """Return the path of ``weights.h5`` below ``<artifact_uri>/<prefix>/train_dev/estimator``."""
        return os.path.join(self.artifact_uri, prefix, "train_dev", "estimator", "weights.h5")

    def get_classes(self, prefix="tt"):
        """Return the class names stored in the run parameter ``<prefix>/train_dev/estimator/classes``.

        The parameter value is parsed as JSON after replacing every single quote with a double quote.
        NOTE(review): that replacement breaks for class names that themselves contain a quote.
        """
        # Parameter names are not file system paths: always join with "/", independent of the OS.
        import posixpath

        parameter_name = posixpath.join(prefix, "train_dev/estimator/classes")
        serialized_classes = self.run.data.params[parameter_name]
        return json.loads(serialized_classes.replace("'", '"'))

    def get_model_handle(self, base_weights=None):
        """Build the model handle of the run and load the run's top weights into it.

        The builder is imported from the dotted path stored in the run parameter
        ``model_handle_builder`` and called with the run's classes and ``base_weights``.
        """
        weights_path = self.get_weights_path()
        model_handle_builder = load_object(self.run.data.params["model_handle_builder"].strip())
        model_handle = model_handle_builder(self.get_classes(), base_weights=base_weights)
        model_handle.load_top_weights(weights_path)
        return model_handle
|
||||
|
||||
|
||||
class MlflowLoader(ModelLoader):
    """Model loader that builds model handles from MLflow runs and keeps the last one it built.

    ``_base_weights`` starts out as ``None`` and has to be set from outside (see the error raised by
    ``load_model``) unless base weights are passed to ``load_model`` directly.

    Args:
        mlruns_dir: Directory handed to ``MlflowModelReader`` as ``mlruns_dir``.
    """

    def __init__(self, mlruns_dir):
        self.__mlruns_dir = mlruns_dir
        self._model_handle = None
        self.__last_run_id = None
        self._base_weights = None

    def load_model(self, run_id, base_weights=None):
        """Return the model handle for ``run_id``, building it only when no matching handle is held.

        A handle is built when none is held yet, or when the held one was built for a different run.
        A handle that was put into ``_model_handle`` without a recorded run id (e.g. injected by a
        test) is returned as is.

        Args:
            run_id: Identifier of the MLflow run to load.
            base_weights: Base weights for the model handle; falls back to ``self._base_weights``.

        Raises:
            IncorrectInstantiation: if neither ``base_weights`` nor ``self._base_weights`` is set.
        """
        if not base_weights:
            if not self._base_weights:
                raise IncorrectInstantiation("MlflowReader needs to be initialized via get_model_loader.")
            base_weights = self._base_weights

        nothing_loaded = self._model_handle is None
        loaded_for_other_run = self.__last_run_id is not None and run_id != self.__last_run_id

        # The previous condition (`not handle and run_id == last_run_id`) could never be true for a
        # fresh loader, so nothing was ever loaded and None was returned.
        if nothing_loaded or loaded_for_other_run:
            mlflow_reader = MlflowModelReader(run_id, mlruns_dir=self.__mlruns_dir)
            self._model_handle = mlflow_reader.get_model_handle(base_weights)
            self.__last_run_id = run_id

        return self._model_handle

    def load_classes(self, run_id):
        """Return the readable class names of the model for ``run_id``, indexed by the estimator's classes.

        NOTE(review): assumes ``model_handle.model.classes_`` is an integer array — confirm.
        """
        model_handle = self.load_model(run_id)

        classes = model_handle.model.classes_
        classes_readable = np.array(model_handle.classes)
        # The inner index selects every position of `classes` in order before it is used as an index.
        classes_readable_aligned = classes_readable[classes[list(range(len(classes)))]]

        return classes_readable_aligned
|
||||
# """This module translates between the new ModelLoader API and the inconsistent and historically grown redai model and
|
||||
# MLflow API as well as the circumstance, that the model artifacts are currently not stored at a single place, due to the
|
||||
# need of loading the base weights of the pre-trained model, that became apparent at a later point than the design of the
|
||||
# MLflow storage and MlflowModelReader class; that is why the code in this module is so unclean. In the future, a
|
||||
# non-adhoc solution should be used that offers a clean API and storage solution. Either implement a well-designed MLflow
|
||||
# based solution or look into an alternative such as WandB or use a platform solution such as AWS.
|
||||
# """
|
||||
# import importlib
|
||||
# import json
|
||||
# import os
|
||||
# import warnings
|
||||
# from typing import Mapping
|
||||
#
|
||||
# import numpy as np
|
||||
# from funcy import rcompose
|
||||
#
|
||||
# from image_prediction.exceptions import IncorrectInstantiation
|
||||
# from image_prediction.model_loader.loader import ModelLoader
|
||||
#
|
||||
# warnings.filterwarnings("ignore", category=DeprecationWarning, module="pkg_resources")
|
||||
#
|
||||
# import mlflow
|
||||
#
|
||||
#
|
||||
# def load_object(object_path):
|
||||
# path_fragments = object_path.split(".")
|
||||
#
|
||||
# module_path = ".".join(path_fragments[:-1])
|
||||
# object_name = path_fragments[-1]
|
||||
#
|
||||
# module = importlib.import_module(module_path)
|
||||
# return getattr(module, object_name)
|
||||
#
|
||||
#
|
||||
# def to_local_path(uri):
|
||||
# return uri[7:]
|
||||
#
|
||||
#
|
||||
# class MlflowModelReader:
|
||||
#
|
||||
# def __init__(self, run_id, mlruns_dir=None):
|
||||
# mlflow.set_tracking_uri(mlruns_dir)
|
||||
#
|
||||
# self.run_id = run_id
|
||||
# self.run = mlflow.get_run(run_id)
|
||||
# self.artifact_uri = self.__correct_artifact_uri(self.run.info.to_proto().artifact_uri, mlruns_dir)
|
||||
#
|
||||
# @staticmethod
|
||||
# def __correct_artifact_uri(run_artifact_uri, base_path):
|
||||
# _, suffix = run_artifact_uri.split("mlruns/")
|
||||
# return os.path.join(base_path, suffix)
|
||||
#
|
||||
# def get_weights_path(self, prefix="tt"):
|
||||
# path = os.path.join(self.artifact_uri, prefix, "train_dev", "estimator", "weights.h5")
|
||||
# return path
|
||||
#
|
||||
# def get_classes(self, prefix="tt"):
|
||||
# classes = json.loads(
|
||||
# self.run.data.params[os.path.join(prefix, "train_dev/estimator/classes")].replace("'", '"')
|
||||
# )
|
||||
# return classes
|
||||
#
|
||||
# def get_model_handle(self, base_weights=None):
|
||||
# weights_path = self.get_weights_path()
|
||||
# model_handle_builder = load_object(self.run.data.params["model_handle_builder"].strip())
|
||||
# model_handle = model_handle_builder(self.get_classes(), base_weights=base_weights)
|
||||
# model_handle.load_top_weights(weights_path)
|
||||
# return model_handle
|
||||
#
|
||||
#
|
||||
# class PredictionModelHandle:
|
||||
# """Simplifies usage of ModelHandle instances for prediction purposes."""
|
||||
#
|
||||
# def __init__(self, model_handle, classes_readable: Mapping[int, str]):
|
||||
# self.__model_handle = model_handle
|
||||
# self.__classes_readable = classes_readable
|
||||
#
|
||||
# @property
|
||||
# def classes(self):
|
||||
# return self.__classes_readable
|
||||
#
|
||||
# def predict(self, *args, **kwargs):
|
||||
# predict = rcompose(self.__model_handle.prep_images, self.__model_handle.model.predict)
|
||||
# return predict(*args, **kwargs)
|
||||
#
|
||||
# def predict_proba(self, *args, **kwargs):
|
||||
# predict = rcompose(self.__model_handle.prep_images, self.__model_handle.model.predict_proba)
|
||||
# return predict(*args, **kwargs)
|
||||
#
|
||||
#
|
||||
# class MlflowLoader(ModelLoader):
|
||||
#
|
||||
# def __init__(self, mlruns_dir):
|
||||
# self.__mlruns_dir = mlruns_dir
|
||||
# self._base_weights = None
|
||||
#
|
||||
# def load_model(self, run_id, base_weights=None) -> PredictionModelHandle:
|
||||
#
|
||||
# # TODO: refac https://stackoverflow.com/questions/42735421/how-to-restrict-object-instantiation-only-via-a-factory-in-python
|
||||
# if not base_weights:
|
||||
#
|
||||
# if not self._base_weights:
|
||||
# raise IncorrectInstantiation("MlflowReader needs to be initialized via get_model_loader.")
|
||||
#
|
||||
# base_weights = self._base_weights
|
||||
#
|
||||
# mlflow_reader = MlflowModelReader(run_id, mlruns_dir=self.__mlruns_dir)
|
||||
# model_handel = mlflow_reader.get_model_handle(base_weights)
|
||||
# model_handle = model_handel
|
||||
# classes_readable = self.__load_classes(model_handle)
|
||||
#
|
||||
# model = PredictionModelHandle(model_handle, classes_readable)
|
||||
#
|
||||
# return model
|
||||
#
|
||||
# @staticmethod
|
||||
# def __load_classes(model_handle):
|
||||
#
|
||||
# classes = model_handle.model.classes_
|
||||
# classes_readable = np.array(model_handle.classes)
|
||||
# classes_readable_aligned = classes_readable[classes[list(range(len(classes)))]]
|
||||
#
|
||||
# return classes_readable_aligned
|
||||
|
||||
@ -9,7 +9,3 @@ class ModelLoaderMock(ModelLoader):
|
||||
def load_model(self, identifier):
    """Return the model set on this mock; ``identifier`` is ignored."""
    model = self.model
    assert model is not None, "Set the model to be returned first via monkeypatching"
    return model
|
||||
|
||||
def load_classes(self, identifier):
    """Return the classes set on this mock; ``identifier`` is ignored."""
    classes = self.classes
    assert classes is not None, "Set the classes to be returned first via monkeypatching"
    return classes
|
||||
|
||||
@ -1,17 +0,0 @@
|
||||
from collections import namedtuple
|
||||
|
||||
from image_prediction.locations import MLRUNS_DIR
|
||||
from image_prediction.model_loader.loader import ModelLoader
|
||||
from image_prediction.model_loader.loaders.mlflow import MlflowLoader
|
||||
|
||||
ModelClassesPair = namedtuple("ModelClassesPair", ["model", "classes"])
|
||||
|
||||
|
||||
def load_model_and_classes(identifier, model_loader: ModelLoader = None) -> ModelClassesPair:
    """Load the model and the classes stored under ``identifier`` and return them as a pair.

    Args:
        identifier: Passed to the loader's ``load_model`` and ``load_classes``.
        model_loader: Loader to use; when falsy, an ``MlflowLoader`` on ``MLRUNS_DIR`` is created.

    Returns:
        A ``ModelClassesPair`` holding the loaded model and classes.
    """
    loader = model_loader or MlflowLoader(MLRUNS_DIR)
    return ModelClassesPair(
        model=loader.load_model(identifier),
        classes=loader.load_classes(identifier),
    )
|
||||
@ -1,4 +1,5 @@
|
||||
import random
|
||||
import string
|
||||
import tempfile
|
||||
from itertools import starmap
|
||||
from operator import itemgetter
|
||||
@ -13,13 +14,12 @@ from image_prediction.classifier.image_classifier import ImageClassifier
|
||||
from image_prediction.estimator.adapter.adapters.keras import KerasEstimatorAdapter
|
||||
from image_prediction.estimator.adapter.adapters.mock import EstimatorMock, EstimatorAdapterMock
|
||||
from image_prediction.estimator.preprocessor.preprocessors.basic import BasicPreprocessor
|
||||
from image_prediction.exceptions import UnknownEstimatorAdapter, UnknownImageExtractor, UnknownModelLoader
|
||||
from image_prediction.exceptions import UnknownEstimatorAdapter, UnknownImageExtractor, UnknownDatabaseType
|
||||
from image_prediction.image_extractor.extractor import ImageMetadataPair
|
||||
from image_prediction.image_extractor.extractors.mock import ImageExtractorMock
|
||||
from image_prediction.image_extractor.extractors.parsable import ParsablePDFImageExtractor
|
||||
from image_prediction.model_loader.loaders.loaders import get_mlflow_loader
|
||||
from image_prediction.model_loader.loaders.mlflow import MlflowLoader
|
||||
from image_prediction.model_loader.loaders.mock import ModelLoaderMock
|
||||
from image_prediction.model_loader.database.connectors.mock import DatabaseConnectorMock
|
||||
from image_prediction.model_loader.loader import ModelLoader
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@ -207,29 +207,79 @@ def add_image_to_last_page(pdf: fpdf.fpdf.FPDF, image_metadata_pair):
|
||||
pdf.image(temp_image.name, x=x, y=y, w=w, h=h, type="png")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def model_handle_mock(classes, classifier):
    """Return a minimal model handle wrapping ``classifier``.

    The handle exposes ``classes`` and ``model``; as a side effect ``classifier.classes_`` is set to
    the indices ``0 .. len(classes) - 1``.
    """

    class ModelHandleMock:

        def __init__(self, classes):
            label_indices = [*range(len(classes))]
            classifier.classes_ = np.array(label_indices)
            self.classes = classes
            self.model = classifier

    handle = ModelHandleMock(classes)
    return handle
|
||||
# @pytest.fixture
|
||||
# def model_handle_mock(classes, classifier):
|
||||
#
|
||||
# class ModelHandleMock:
|
||||
#
|
||||
# def __init__(self, classes):
|
||||
# classifier.classes_ = np.array(list(range(len(classes))))
|
||||
# self.classes = classes
|
||||
# self.model = classifier
|
||||
#
|
||||
# return ModelHandleMock(classes)
|
||||
#
|
||||
#
|
||||
# @pytest.fixture
|
||||
# def prediction_model_handle_mock(model_handle_mock, classes):
|
||||
# return PredictionModelHandle(model_handle_mock, classes)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def model_loader(loader_type, monkeypatch, model_handle_mock, classes):
|
||||
if loader_type == "mock":
|
||||
loader = ModelLoaderMock()
|
||||
monkeypatch.setattr(loader, "model", model_handle_mock)
|
||||
monkeypatch.setattr(loader, "classes", classes)
|
||||
elif loader_type == "mlflow":
|
||||
loader = get_mlflow_loader()
|
||||
monkeypatch.setattr(loader, "_model_handle", model_handle_mock)
|
||||
def model():
    """Return a stub model whose ``predict`` and ``predict_proba`` take any positional arguments and return ``True``."""

    def _always_true(*args):
        return True

    class Model:
        predict = staticmethod(_always_true)
        predict_proba = staticmethod(_always_true)

    return Model()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def model_database_record_identifier():
    """Return a random identifier made of ten distinct ASCII letters."""
    letters = random.sample(string.ascii_letters, k=10)
    return "".join(letters)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def model_database_record(model, classes):
    """Return a database record holding ``model`` and ``classes`` under the keys of the same name."""
    return dict(model=model, classes=classes)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def model_database(model_database_record, model_database_record_identifier):
    """Return a single-entry database mapping the identifier to the record."""
    database = {}
    database[model_database_record_identifier] = model_database_record
    return database
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def database_connector(database_type, model_database):
    """Return the database connector for ``database_type``, backed by ``model_database``.

    Only the type ``"mock"`` is known.

    Raises:
        UnknownDatabaseType: for any other ``database_type``.
    """
    if database_type == "mock":
        return DatabaseConnectorMock(model_database)

    # A leftover `raise UnknownModelLoader(...)` used to precede this line; it referenced the undefined
    # name `loader_type` and made this error unreachable.
    raise UnknownDatabaseType(f"No connector for database type {database_type} was specified.")
|
||||
|
||||
return loader
|
||||
|
||||
@pytest.fixture
|
||||
def model_loader(database_connector):
    """Return a ``ModelLoader`` that reads from ``database_connector``."""
    loader = ModelLoader(database_connector)
    return loader
|
||||
|
||||
|
||||
# @pytest.fixture
|
||||
# def model_loader(loader_type, monkeypatch, model_handle_mock, classes):
|
||||
# if loader_type == "mock":
|
||||
# loader = ModelLoaderMock()
|
||||
# monkeypatch.setattr(loader, "model", model_handle_mock)
|
||||
#
|
||||
# # elif loader_type == "mlflow":
|
||||
# # loader = get_mlflow_loader()
|
||||
# # monkeypatch.setattr(loader, "_model_handle", model_handle_mock)
|
||||
#
|
||||
# else:
|
||||
# raise UnknownModelLoader(f"No model loader for type {loader_type} was specified.")
|
||||
#
|
||||
# return loader
|
||||
|
||||
@ -1,15 +1,19 @@
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from image_prediction.model_loading import load_model_and_classes
|
||||
|
||||
# @pytest.mark.parametrize("loader_type", ["mock"])
|
||||
# @pytest.mark.parametrize("estimator_type", ["mock"])
|
||||
# @pytest.mark.parametrize("batch_size", [3])
|
||||
# def test_load_model_and_classes(model_loader, model_handle_mock, classes):
|
||||
# model_loaded, classes_loaded = model_loader.load_model_and_classes("an identifier")
|
||||
# assert model_loaded == model_handle_mock
|
||||
# assert np.all(classes_loaded == classes)
|
||||
|
||||
@pytest.mark.parametrize("loader_type", ["mock", "mlflow"])
|
||||
@pytest.mark.parametrize("estimator_type", ["mock"])
|
||||
@pytest.mark.parametrize("batch_size", [3])
|
||||
def test_load_model_and_classes(model_loader, model_handle_mock, classes):
|
||||
# Load twice to test caching logic
|
||||
for _ in range(2):
|
||||
model_loaded, classes_loaded = load_model_and_classes("some random identifier", model_loader=model_loader)
|
||||
assert model_loaded == model_handle_mock
|
||||
assert np.all(classes_loaded == classes)
|
||||
@pytest.mark.parametrize("database_type", ["mock"])
|
||||
def test_load_model_and_classes(model_loader, model_database_record_identifier, model, classes):
    """The loader returns the model and the classes stored under the record identifier."""
    identifier = model_database_record_identifier
    loaded = {
        "model": model_loader.load_model(identifier),
        "classes": model_loader.load_classes(identifier),
    }

    assert loaded["model"] == model
    assert loaded["classes"] == classes
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user