redoing model loading design
This commit is contained in:
parent
a1c7dd4a8d
commit
f60bafd007
@ -10,5 +10,9 @@ class UnknownModelLoader(ValueError):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class UnknownDatabaseType(ValueError):
    """Raised when no database connector is available for a requested database type."""
|
||||||
|
|
||||||
|
|
||||||
class IncorrectInstantiation(RuntimeError):
    """Raised when an object is used without having been set up the intended way.

    NOTE(review): the only visible use is the MLflow loader complaining that it
    was not created through its factory function; confirm no other callers
    rely on a different meaning.
    """
|
||||||
|
|||||||
10
image_prediction/model_loader/database/connectors/mock.py
Normal file
10
image_prediction/model_loader/database/connectors/mock.py
Normal file
@ -0,0 +1,10 @@
|
|||||||
|
from image_prediction.model_loader.database.connector import DatabaseConnector
|
||||||
|
|
||||||
|
|
||||||
|
class DatabaseConnectorMock(DatabaseConnector):
    """Database connector stand-in that serves objects from an in-memory mapping."""

    def __init__(self, store: dict):
        """Keep a reference to ``store``; lookups are answered directly from it."""
        self.store = store

    def get_object(self, identifier):
        """Return the object stored under ``identifier``.

        The lookup is a plain subscript on ``store``, so an unknown identifier
        raises whatever the mapping raises (``KeyError`` for a ``dict``).
        """
        return self.store[identifier]
|
||||||
@ -1,12 +1,19 @@
|
|||||||
import abc
|
from functools import lru_cache
|
||||||
|
|
||||||
|
from image_prediction.model_loader.database.connector import DatabaseConnector
|
||||||
|
|
||||||
|
|
||||||
class ModelLoader:
    """Loads a model and its classes from records served by a database connector.

    A record is whatever ``database_connector.get_object(identifier)`` returns;
    it is subscripted with ``"model"`` and ``"classes"``. Each record is fetched
    at most once per loader instance and identifier.
    """

    def __init__(self, database_connector: DatabaseConnector):
        """Store the connector and start with an empty record cache."""
        self.database_connector = database_connector
        # Per-instance cache. Decorating the method with functools.lru_cache
        # would create one class-wide, unbounded cache keyed on (self, identifier)
        # that keeps every loader (and its connector and loaded models) alive
        # for the lifetime of the process and requires loaders to be hashable.
        self.__records = {}

    def __get_object(self, identifier):
        """Return the record for ``identifier``, querying the connector only on a cache miss.

        Exceptions raised by the connector propagate and nothing is cached for
        that identifier, so a later call retries the lookup.
        """
        if identifier not in self.__records:
            self.__records[identifier] = self.database_connector.get_object(identifier)
        return self.__records[identifier]

    def load_model(self, identifier):
        """Return the ``"model"`` entry of the record stored under ``identifier``."""
        return self.__get_object(identifier)["model"]

    def load_classes(self, identifier):
        """Return the ``"classes"`` entry of the record stored under ``identifier``."""
        return self.__get_object(identifier)["classes"]
|
||||||
|
|||||||
@ -1,101 +1,123 @@
|
|||||||
"""This module translates between the new ModelLoader API and the inconsistent and historically grown redai model and
|
# """This module translates between the new ModelLoader API and the inconsistent and historically grown redai model and
|
||||||
MLflow API as well as the circumstance, that the model artifacts are currently not stored at a single place, due to the
|
# MLflow API as well as the circumstance, that the model artifacts are currently not stored at a single place, due to the
|
||||||
need of loading the base weights of the pre-trained model, that became apparent at a later point than the design of the
|
# need of loading the base weights of the pre-trained model, that became apparent at a later point than the design of the
|
||||||
MLflow storage and MlflowModelReader class; that is why the code in this module is so unclean. In the future, a
|
# MLflow storage and MlflowModelReader class; that is why the code in this module is so unclean. In the future, a
|
||||||
non-adhoc solution should be used that offers a clean API and storage solution. Either implement a well-designed MLflow
|
# non-adhoc solution should be used that offers a clean API and storage solution. Either implement a well-designed MLflow
|
||||||
based solution or look into an alternative such as WandB or use a platform solution such as AWS.
|
# based solution or look into an alternative such as WandB or use a platform solution such as AWS.
|
||||||
"""
|
# """
|
||||||
import importlib
|
# import importlib
|
||||||
import json
|
# import json
|
||||||
import os
|
# import os
|
||||||
import warnings
|
# import warnings
|
||||||
|
# from typing import Mapping
|
||||||
import numpy as np
|
#
|
||||||
|
# import numpy as np
|
||||||
from image_prediction.exceptions import IncorrectInstantiation
|
# from funcy import rcompose
|
||||||
from image_prediction.model_loader.loader import ModelLoader
|
#
|
||||||
|
# from image_prediction.exceptions import IncorrectInstantiation
|
||||||
warnings.filterwarnings("ignore", category=DeprecationWarning, module="pkg_resources")
|
# from image_prediction.model_loader.loader import ModelLoader
|
||||||
|
#
|
||||||
import mlflow
|
# warnings.filterwarnings("ignore", category=DeprecationWarning, module="pkg_resources")
|
||||||
|
#
|
||||||
|
# import mlflow
|
||||||
def load_object(object_path):
|
#
|
||||||
path_fragments = object_path.split(".")
|
#
|
||||||
|
# def load_object(object_path):
|
||||||
module_path = ".".join(path_fragments[:-1])
|
# path_fragments = object_path.split(".")
|
||||||
object_name = path_fragments[-1]
|
#
|
||||||
|
# module_path = ".".join(path_fragments[:-1])
|
||||||
module = importlib.import_module(module_path)
|
# object_name = path_fragments[-1]
|
||||||
return getattr(module, object_name)
|
#
|
||||||
|
# module = importlib.import_module(module_path)
|
||||||
|
# return getattr(module, object_name)
|
||||||
def to_local_path(uri):
|
#
|
||||||
return uri[7:]
|
#
|
||||||
|
# def to_local_path(uri):
|
||||||
|
# return uri[7:]
|
||||||
class MlflowModelReader:
|
#
|
||||||
|
#
|
||||||
def __init__(self, run_id, mlruns_dir=None):
|
# class MlflowModelReader:
|
||||||
mlflow.set_tracking_uri(mlruns_dir)
|
#
|
||||||
|
# def __init__(self, run_id, mlruns_dir=None):
|
||||||
self.run_id = run_id
|
# mlflow.set_tracking_uri(mlruns_dir)
|
||||||
self.run = mlflow.get_run(run_id)
|
#
|
||||||
self.artifact_uri = self.__correct_artifact_uri(self.run.info.to_proto().artifact_uri, mlruns_dir)
|
# self.run_id = run_id
|
||||||
|
# self.run = mlflow.get_run(run_id)
|
||||||
@staticmethod
|
# self.artifact_uri = self.__correct_artifact_uri(self.run.info.to_proto().artifact_uri, mlruns_dir)
|
||||||
def __correct_artifact_uri(run_artifact_uri, base_path):
|
#
|
||||||
_, suffix = run_artifact_uri.split("mlruns/")
|
# @staticmethod
|
||||||
return os.path.join(base_path, suffix)
|
# def __correct_artifact_uri(run_artifact_uri, base_path):
|
||||||
|
# _, suffix = run_artifact_uri.split("mlruns/")
|
||||||
def get_weights_path(self, prefix="tt"):
|
# return os.path.join(base_path, suffix)
|
||||||
path = os.path.join(self.artifact_uri, prefix, "train_dev", "estimator", "weights.h5")
|
#
|
||||||
return path
|
# def get_weights_path(self, prefix="tt"):
|
||||||
|
# path = os.path.join(self.artifact_uri, prefix, "train_dev", "estimator", "weights.h5")
|
||||||
def get_classes(self, prefix="tt"):
|
# return path
|
||||||
classes = json.loads(
|
#
|
||||||
self.run.data.params[os.path.join(prefix, "train_dev/estimator/classes")].replace("'", '"')
|
# def get_classes(self, prefix="tt"):
|
||||||
)
|
# classes = json.loads(
|
||||||
return classes
|
# self.run.data.params[os.path.join(prefix, "train_dev/estimator/classes")].replace("'", '"')
|
||||||
|
# )
|
||||||
def get_model_handle(self, base_weights=None):
|
# return classes
|
||||||
weights_path = self.get_weights_path()
|
#
|
||||||
model_handle_builder = load_object(self.run.data.params["model_handle_builder"].strip())
|
# def get_model_handle(self, base_weights=None):
|
||||||
model_handle = model_handle_builder(self.get_classes(), base_weights=base_weights)
|
# weights_path = self.get_weights_path()
|
||||||
model_handle.load_top_weights(weights_path)
|
# model_handle_builder = load_object(self.run.data.params["model_handle_builder"].strip())
|
||||||
return model_handle
|
# model_handle = model_handle_builder(self.get_classes(), base_weights=base_weights)
|
||||||
|
# model_handle.load_top_weights(weights_path)
|
||||||
|
# return model_handle
|
||||||
class MlflowLoader(ModelLoader):
|
#
|
||||||
|
#
|
||||||
def __init__(self, mlruns_dir):
|
# class PredictionModelHandle:
|
||||||
self.__mlruns_dir = mlruns_dir
|
# """Simplifies usage of ModelHandle instances for prediction purposes."""
|
||||||
self._model_handle = None
|
#
|
||||||
self.__last_run_id = None
|
# def __init__(self, model_handle, classes_readable: Mapping[int, str]):
|
||||||
self._base_weights = None
|
# self.__model_handle = model_handle
|
||||||
|
# self.__classes_readable = classes_readable
|
||||||
def load_model(self, run_id, base_weights=None):
|
#
|
||||||
|
# @property
|
||||||
if not base_weights:
|
# def classes(self):
|
||||||
|
# return self.__classes_readable
|
||||||
if not self._base_weights:
|
#
|
||||||
raise IncorrectInstantiation("MlflowReader needs to be initialized via get_model_loader.")
|
# def predict(self, *args, **kwargs):
|
||||||
|
# predict = rcompose(self.__model_handle.prep_images, self.__model_handle.model.predict)
|
||||||
base_weights = self._base_weights
|
# return predict(*args, **kwargs)
|
||||||
|
#
|
||||||
if not self._model_handle and run_id == self.__last_run_id:
|
# def predict_proba(self, *args, **kwargs):
|
||||||
mlflow_reader = MlflowModelReader(run_id, mlruns_dir=self.__mlruns_dir)
|
# predict = rcompose(self.__model_handle.prep_images, self.__model_handle.model.predict_proba)
|
||||||
model_handel = mlflow_reader.get_model_handle(base_weights)
|
# return predict(*args, **kwargs)
|
||||||
self._model_handle = model_handel
|
#
|
||||||
self.__last_run_id = run_id
|
#
|
||||||
|
# class MlflowLoader(ModelLoader):
|
||||||
return self._model_handle
|
#
|
||||||
|
# def __init__(self, mlruns_dir):
|
||||||
def load_classes(self, run_id):
|
# self.__mlruns_dir = mlruns_dir
|
||||||
model_handle = self.load_model(run_id)
|
# self._base_weights = None
|
||||||
|
#
|
||||||
classes = model_handle.model.classes_
|
# def load_model(self, run_id, base_weights=None) -> PredictionModelHandle:
|
||||||
classes_readable = np.array(model_handle.classes)
|
#
|
||||||
classes_readable_aligned = classes_readable[classes[list(range(len(classes)))]]
|
# # TODO: refac https://stackoverflow.com/questions/42735421/how-to-restrict-object-instantiation-only-via-a-factory-in-python
|
||||||
|
# if not base_weights:
|
||||||
return classes_readable_aligned
|
#
|
||||||
|
# if not self._base_weights:
|
||||||
|
# raise IncorrectInstantiation("MlflowReader needs to be initialized via get_model_loader.")
|
||||||
|
#
|
||||||
|
# base_weights = self._base_weights
|
||||||
|
#
|
||||||
|
# mlflow_reader = MlflowModelReader(run_id, mlruns_dir=self.__mlruns_dir)
|
||||||
|
# model_handel = mlflow_reader.get_model_handle(base_weights)
|
||||||
|
# model_handle = model_handel
|
||||||
|
# classes_readable = self.__load_classes(model_handle)
|
||||||
|
#
|
||||||
|
# model = PredictionModelHandle(model_handle, classes_readable)
|
||||||
|
#
|
||||||
|
# return model
|
||||||
|
#
|
||||||
|
# @staticmethod
|
||||||
|
# def __load_classes(model_handle):
|
||||||
|
#
|
||||||
|
# classes = model_handle.model.classes_
|
||||||
|
# classes_readable = np.array(model_handle.classes)
|
||||||
|
# classes_readable_aligned = classes_readable[classes[list(range(len(classes)))]]
|
||||||
|
#
|
||||||
|
# return classes_readable_aligned
|
||||||
|
|||||||
@ -9,7 +9,3 @@ class ModelLoaderMock(ModelLoader):
|
|||||||
def load_model(self, identifier):
    """Return the model previously set on this mock; ``identifier`` is not used.

    Fails with an ``AssertionError`` while ``self.model`` is still ``None``.
    """
    assert self.model is not None, "Set the model to be returned first via monkeypatching"
    return self.model
|
||||||
|
|
||||||
def load_classes(self, identifier):
|
|
||||||
assert self.classes is not None, "Set the classes to be returned first via monkeypatching"
|
|
||||||
return self.classes
|
|
||||||
|
|||||||
@ -1,17 +0,0 @@
|
|||||||
from collections import namedtuple
|
|
||||||
|
|
||||||
from image_prediction.locations import MLRUNS_DIR
|
|
||||||
from image_prediction.model_loader.loader import ModelLoader
|
|
||||||
from image_prediction.model_loader.loaders.mlflow import MlflowLoader
|
|
||||||
|
|
||||||
ModelClassesPair = namedtuple("ModelClassesPair", ["model", "classes"])
|
|
||||||
|
|
||||||
|
|
||||||
def load_model_and_classes(identifier, model_loader: ModelLoader = None) -> ModelClassesPair:
|
|
||||||
if not model_loader:
|
|
||||||
model_loader = MlflowLoader(MLRUNS_DIR)
|
|
||||||
|
|
||||||
model = model_loader.load_model(identifier)
|
|
||||||
classes = model_loader.load_classes(identifier)
|
|
||||||
|
|
||||||
return ModelClassesPair(model, classes)
|
|
||||||
@ -1,4 +1,5 @@
|
|||||||
import random
|
import random
|
||||||
|
import string
|
||||||
import tempfile
|
import tempfile
|
||||||
from itertools import starmap
|
from itertools import starmap
|
||||||
from operator import itemgetter
|
from operator import itemgetter
|
||||||
@ -13,13 +14,12 @@ from image_prediction.classifier.image_classifier import ImageClassifier
|
|||||||
from image_prediction.estimator.adapter.adapters.keras import KerasEstimatorAdapter
|
from image_prediction.estimator.adapter.adapters.keras import KerasEstimatorAdapter
|
||||||
from image_prediction.estimator.adapter.adapters.mock import EstimatorMock, EstimatorAdapterMock
|
from image_prediction.estimator.adapter.adapters.mock import EstimatorMock, EstimatorAdapterMock
|
||||||
from image_prediction.estimator.preprocessor.preprocessors.basic import BasicPreprocessor
|
from image_prediction.estimator.preprocessor.preprocessors.basic import BasicPreprocessor
|
||||||
from image_prediction.exceptions import UnknownEstimatorAdapter, UnknownImageExtractor, UnknownModelLoader
|
from image_prediction.exceptions import UnknownEstimatorAdapter, UnknownImageExtractor, UnknownDatabaseType
|
||||||
from image_prediction.image_extractor.extractor import ImageMetadataPair
|
from image_prediction.image_extractor.extractor import ImageMetadataPair
|
||||||
from image_prediction.image_extractor.extractors.mock import ImageExtractorMock
|
from image_prediction.image_extractor.extractors.mock import ImageExtractorMock
|
||||||
from image_prediction.image_extractor.extractors.parsable import ParsablePDFImageExtractor
|
from image_prediction.image_extractor.extractors.parsable import ParsablePDFImageExtractor
|
||||||
from image_prediction.model_loader.loaders.loaders import get_mlflow_loader
|
from image_prediction.model_loader.database.connectors.mock import DatabaseConnectorMock
|
||||||
from image_prediction.model_loader.loaders.mlflow import MlflowLoader
|
from image_prediction.model_loader.loader import ModelLoader
|
||||||
from image_prediction.model_loader.loaders.mock import ModelLoaderMock
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
@ -207,29 +207,79 @@ def add_image_to_last_page(pdf: fpdf.fpdf.FPDF, image_metadata_pair):
|
|||||||
pdf.image(temp_image.name, x=x, y=y, w=w, h=h, type="png")
|
pdf.image(temp_image.name, x=x, y=y, w=w, h=h, type="png")
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
# @pytest.fixture
|
||||||
def model_handle_mock(classes, classifier):
|
# def model_handle_mock(classes, classifier):
|
||||||
|
#
|
||||||
class ModelHandleMock:
|
# class ModelHandleMock:
|
||||||
|
#
|
||||||
def __init__(self, classes):
|
# def __init__(self, classes):
|
||||||
classifier.classes_ = np.array(list(range(len(classes))))
|
# classifier.classes_ = np.array(list(range(len(classes))))
|
||||||
self.classes = classes
|
# self.classes = classes
|
||||||
self.model = classifier
|
# self.model = classifier
|
||||||
|
#
|
||||||
return ModelHandleMock(classes)
|
# return ModelHandleMock(classes)
|
||||||
|
#
|
||||||
|
#
|
||||||
|
# @pytest.fixture
|
||||||
|
# def prediction_model_handle_mock(model_handle_mock, classes):
|
||||||
|
# return PredictionModelHandle(model_handle_mock, classes)
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
def model():
    """Provide a minimal model stub whose prediction methods accept anything and return ``True``."""

    class Model:
        @staticmethod
        def predict(*args):
            return True

        @staticmethod
        def predict_proba(*args):
            return True

    return Model()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def model_database_record_identifier():
    """Provide a random identifier made of ten distinct ASCII letters."""
    # random.sample draws without replacement, so no letter repeats.
    letters = random.sample(string.ascii_letters, k=10)
    return "".join(letters)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def model_database_record(model, classes):
    """Provide one database record bundling the ``model`` and ``classes`` fixtures."""
    record = {"model": model, "classes": classes}
    return record
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def model_database(model_database_record, model_database_record_identifier):
    """Provide a one-entry database: the record fixture stored under the identifier fixture."""
    database = {model_database_record_identifier: model_database_record}
    return database
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def database_connector(database_type, model_database):
    """Provide the connector matching ``database_type``, backed by ``model_database``.

    Only ``"mock"`` is supported; any other value raises ``UnknownDatabaseType``.
    """
    if database_type == "mock":
        return DatabaseConnectorMock(model_database)
    else:
        raise UnknownDatabaseType(f"No connector for database type {database_type} was specified.")
|
|
||||||
|
@pytest.fixture
def model_loader(database_connector):
    """Provide a ``ModelLoader`` wired to the ``database_connector`` fixture."""
    loader = ModelLoader(database_connector)
    return loader
|
||||||
|
|
||||||
|
|
||||||
|
# @pytest.fixture
|
||||||
|
# def model_loader(loader_type, monkeypatch, model_handle_mock, classes):
|
||||||
|
# if loader_type == "mock":
|
||||||
|
# loader = ModelLoaderMock()
|
||||||
|
# monkeypatch.setattr(loader, "model", model_handle_mock)
|
||||||
|
#
|
||||||
|
# # elif loader_type == "mlflow":
|
||||||
|
# # loader = get_mlflow_loader()
|
||||||
|
# # monkeypatch.setattr(loader, "_model_handle", model_handle_mock)
|
||||||
|
#
|
||||||
|
# else:
|
||||||
|
# raise UnknownModelLoader(f"No model loader for type {loader_type} was specified.")
|
||||||
|
#
|
||||||
|
# return loader
|
||||||
|
|||||||
@ -1,15 +1,19 @@
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from image_prediction.model_loading import load_model_and_classes
|
|
||||||
|
|
||||||
|
# @pytest.mark.parametrize("loader_type", ["mock"])
|
||||||
|
# @pytest.mark.parametrize("estimator_type", ["mock"])
|
||||||
|
# @pytest.mark.parametrize("batch_size", [3])
|
||||||
|
# def test_load_model_and_classes(model_loader, model_handle_mock, classes):
|
||||||
|
# model_loaded, classes_loaded = model_loader.load_model_and_classes("an identifier")
|
||||||
|
# assert model_loaded == model_handle_mock
|
||||||
|
# assert np.all(classes_loaded == classes)
|
||||||
|
|
||||||
@pytest.mark.parametrize("database_type", ["mock"])
def test_load_model_and_classes(model_loader, model_database_record_identifier, model, classes):
    """The loader returns the model and classes stored under the record identifier."""
    model_loaded = model_loader.load_model(model_database_record_identifier)
    classes_loaded = model_loader.load_classes(model_database_record_identifier)

    assert model_loaded == model
    # NOTE(review): a plain == assumes `classes` is not a multi-element numpy
    # array (the earlier version of this test used np.all) - confirm against
    # the `classes` fixture, which is not visible here.
    assert classes_loaded == classes
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user