redoing model loading design

This commit is contained in:
Matthias Bisping 2022-03-29 17:25:06 +02:00
parent a1c7dd4a8d
commit f60bafd007
8 changed files with 241 additions and 165 deletions

View File

@ -10,5 +10,9 @@ class UnknownModelLoader(ValueError):
pass pass
class UnknownDatabaseType(ValueError):
    """Raised when no connector is available for the requested database type."""
class IncorrectInstantiation(RuntimeError):
    """Raised when an object was not initialised through its designated factory."""

View File

@ -0,0 +1,10 @@
from image_prediction.model_loader.database.connector import DatabaseConnector
class DatabaseConnectorMock(DatabaseConnector):
    """Connector stand-in that serves objects from an in-memory mapping."""

    def __init__(self, store: dict):
        # The mapping is referenced, not copied, so later changes made by the
        # caller are visible through this connector.
        self.store = store

    def get_object(self, identifier):
        """Return the entry stored under ``identifier``.

        An unknown identifier surfaces as the mapping's own lookup error
        (``KeyError`` for a ``dict``).
        """
        record = self.store[identifier]
        return record

View File

@ -1,12 +1,19 @@
import abc from functools import lru_cache
from image_prediction.model_loader.database.connector import DatabaseConnector
class ModelLoader:
    """Load models and their class labels through a database connector.

    Every record returned by the connector must support lookups with the keys
    ``"model"`` and ``"classes"``. Records are cached per loader instance, so
    asking for the same identifier again does not hit the connector a second
    time.
    """

    def __init__(self, database_connector: "DatabaseConnector"):
        self.database_connector = database_connector
        # Per-instance cache. Decorating the method with ``functools.lru_cache``
        # would key a class-level cache on ``self`` and thereby keep every
        # loader (and all records it fetched) alive for the process lifetime.
        self.__records = {}

    def __get_object(self, identifier):
        """Return the record for ``identifier``, fetching it on first use.

        Errors raised by the connector propagate and nothing is cached for
        that identifier, so a later call tries again.
        """
        try:
            return self.__records[identifier]
        except KeyError:
            pass  # not cached yet; fetch outside the handler to avoid chaining
        record = self.database_connector.get_object(identifier)
        self.__records[identifier] = record
        return record

    def load_model(self, identifier):
        """Return the ``"model"`` entry of the record stored under ``identifier``."""
        return self.__get_object(identifier)["model"]

    def load_classes(self, identifier):
        """Return the ``"classes"`` entry of the record stored under ``identifier``."""
        return self.__get_object(identifier)["classes"]

View File

@ -1,101 +1,123 @@
"""This module translates between the new ModelLoader API and the inconsistent and historically grown redai model and # """This module translates between the new ModelLoader API and the inconsistent and historically grown redai model and
MLflow API as well as the circumstance, that the model artifacts are currently not stored at a single place, due to the # MLflow API as well as the circumstance, that the model artifacts are currently not stored at a single place, due to the
need of loading the base weights of the pre-trained model, that became apparent at a later point than the design of the # need of loading the base weights of the pre-trained model, that became apparent at a later point than the design of the
MLflow storage and MlflowModelReader class; that is why the code in this module is so unclean. In the future, a # MLflow storage and MlflowModelReader class; that is why the code in this module is so unclean. In the future, a
non-adhoc solution should be used that offers a clean API and storage solution. Either implement a well-designed MLflow # non-adhoc solution should be used that offers a clean API and storage solution. Either implement a well-designed MLflow
based solution or look into an alternative such as WandB or use a platform solution such as AWS. # based solution or look into an alternative such as WandB or use a platform solution such as AWS.
""" # """
import importlib # import importlib
import json # import json
import os # import os
import warnings # import warnings
# from typing import Mapping
import numpy as np #
# import numpy as np
from image_prediction.exceptions import IncorrectInstantiation # from funcy import rcompose
from image_prediction.model_loader.loader import ModelLoader #
# from image_prediction.exceptions import IncorrectInstantiation
warnings.filterwarnings("ignore", category=DeprecationWarning, module="pkg_resources") # from image_prediction.model_loader.loader import ModelLoader
#
import mlflow # warnings.filterwarnings("ignore", category=DeprecationWarning, module="pkg_resources")
#
# import mlflow
def load_object(object_path):
    """Import and return the object named by a dotted path.

    Everything before the last dot is imported as a module and the final
    fragment is fetched from it as an attribute, e.g. ``"math.sqrt"``.
    """
    module_name, _, attribute_name = object_path.rpartition(".")
    return getattr(importlib.import_module(module_name), attribute_name)
# module = importlib.import_module(module_path)
# return getattr(module, object_name)
def to_local_path(uri):
    """Return ``uri`` without its first seven characters.

    NOTE(review): seven is the length of ``"file://"``, so this presumably turns
    a file URI into a plain path; the prefix itself is never checked.
    """
    prefix_length = 7
    return uri[prefix_length:]
# def to_local_path(uri):
# return uri[7:]
class MlflowModelReader:
    """Read the weights path, class list and model handle of one MLflow run."""

    def __init__(self, run_id, mlruns_dir=None):
        """Point MLflow at ``mlruns_dir`` and fetch the run ``run_id``."""
        mlflow.set_tracking_uri(mlruns_dir)

        self.run_id = run_id
        self.run = mlflow.get_run(run_id)
        self.artifact_uri = self.__correct_artifact_uri(self.run.info.to_proto().artifact_uri, mlruns_dir)

    @staticmethod
    def __correct_artifact_uri(run_artifact_uri, base_path):
        """Re-root the part of the recorded URI after ``mlruns/`` under ``base_path``."""
        # Unpacking into two names requires exactly one "mlruns/" in the URI;
        # anything else raises ValueError.
        _, suffix = run_artifact_uri.split("mlruns/")
        return os.path.join(base_path, suffix)

    def get_weights_path(self, prefix="tt"):
        """Return the path of ``weights.h5`` below the run's artifact directory."""
        return os.path.join(self.artifact_uri, prefix, "train_dev", "estimator", "weights.h5")

    def get_classes(self, prefix="tt"):
        """Return the class list stored in the run parameters.

        The parameter value is a list literal written with single quotes; the
        quotes are swapped for double quotes so it can be parsed as JSON.
        """
        import posixpath

        # The key is a parameter name, not a filesystem path, and its fixed
        # part already uses forward slashes, so it is joined with "/" on every
        # platform instead of the OS-specific separator.
        key = posixpath.join(prefix, "train_dev/estimator/classes")
        raw_classes = self.run.data.params[key]
        return json.loads(raw_classes.replace("'", '"'))

    def get_model_handle(self, base_weights=None):
        """Build the model handle for this run and load its top weights.

        The builder is looked up via ``load_object`` from the run parameter
        ``model_handle_builder`` and called with the run's classes and
        ``base_weights``.
        """
        weights_path = self.get_weights_path()
        model_handle_builder = load_object(self.run.data.params["model_handle_builder"].strip())
        model_handle = model_handle_builder(self.get_classes(), base_weights=base_weights)
        model_handle.load_top_weights(weights_path)
        return model_handle
# model_handle.load_top_weights(weights_path)
# return model_handle
class MlflowLoader(ModelLoader):
    """Model loader backed by MLflow runs that keeps the last loaded handle."""

    def __init__(self, mlruns_dir):
        self.__mlruns_dir = mlruns_dir
        self._model_handle = None
        self.__last_run_id = None
        self._base_weights = None

    def load_model(self, run_id, base_weights=None):
        """Return the model handle for ``run_id``, loading it when necessary.

        The handle is read from MLflow when none is held yet, or when the held
        one was loaded for a different run. A handle assigned to
        ``_model_handle`` from outside (no run id recorded) is returned as is.

        Raises:
            IncorrectInstantiation: if neither ``base_weights`` nor
                ``self._base_weights`` is set.
        """
        if not base_weights:
            if not self._base_weights:
                raise IncorrectInstantiation("MlflowReader needs to be initialized via get_model_loader.")
            base_weights = self._base_weights

        # The previous condition (`not handle and run_id == last_run_id`) was
        # never true on a fresh loader, so nothing was ever loaded and None was
        # returned.
        nothing_cached = self._model_handle is None
        other_run_cached = self.__last_run_id is not None and run_id != self.__last_run_id
        if nothing_cached or other_run_cached:
            mlflow_reader = MlflowModelReader(run_id, mlruns_dir=self.__mlruns_dir)
            self._model_handle = mlflow_reader.get_model_handle(base_weights)
            self.__last_run_id = run_id

        return self._model_handle

    def load_classes(self, run_id):
        """Return the readable class names ordered like the model's ``classes_``."""
        model_handle = self.load_model(run_id)

        classes = model_handle.model.classes_
        classes_readable = np.array(model_handle.classes)
        # presumably `classes` is an integer array used as indices into the
        # readable names; verify against the estimator in use
        classes_readable_aligned = classes_readable[classes[list(range(len(classes)))]]

        return classes_readable_aligned
# if not self._base_weights:
# raise IncorrectInstantiation("MlflowReader needs to be initialized via get_model_loader.")
#
# base_weights = self._base_weights
#
# mlflow_reader = MlflowModelReader(run_id, mlruns_dir=self.__mlruns_dir)
# model_handel = mlflow_reader.get_model_handle(base_weights)
# model_handle = model_handel
# classes_readable = self.__load_classes(model_handle)
#
# model = PredictionModelHandle(model_handle, classes_readable)
#
# return model
#
# @staticmethod
# def __load_classes(model_handle):
#
# classes = model_handle.model.classes_
# classes_readable = np.array(model_handle.classes)
# classes_readable_aligned = classes_readable[classes[list(range(len(classes)))]]
#
# return classes_readable_aligned

View File

@ -9,7 +9,3 @@ class ModelLoaderMock(ModelLoader):
def load_model(self, identifier):
    """Return the model attached to this mock; ``identifier`` is not used."""
    model = self.model
    assert model is not None, "Set the model to be returned first via monkeypatching"
    return model
def load_classes(self, identifier):
    """Return the classes attached to this mock; ``identifier`` is not used."""
    classes = self.classes
    assert classes is not None, "Set the classes to be returned first via monkeypatching"
    return classes

View File

@ -1,17 +0,0 @@
from collections import namedtuple
from image_prediction.locations import MLRUNS_DIR
from image_prediction.model_loader.loader import ModelLoader
from image_prediction.model_loader.loaders.mlflow import MlflowLoader
# Pairs a loaded model with the classes that belong to it.
ModelClassesPair = namedtuple("ModelClassesPair", ("model", "classes"))
def load_model_and_classes(identifier, model_loader: ModelLoader = None) -> ModelClassesPair:
    """Load the model and its classes for ``identifier`` as one pair.

    Args:
        identifier: passed unchanged to the loader's ``load_model`` and
            ``load_classes``.
        model_loader: loader to use; when omitted, an ``MlflowLoader`` on
            ``MLRUNS_DIR`` is created.

    Returns:
        A ``ModelClassesPair`` of the loaded model and classes.
    """
    # Compare against None explicitly: a loader object that happens to be
    # falsy (e.g. defines __len__ or __bool__) must not be replaced silently.
    if model_loader is None:
        model_loader = MlflowLoader(MLRUNS_DIR)

    return ModelClassesPair(
        model=model_loader.load_model(identifier),
        classes=model_loader.load_classes(identifier),
    )

View File

@ -1,4 +1,5 @@
import random import random
import string
import tempfile import tempfile
from itertools import starmap from itertools import starmap
from operator import itemgetter from operator import itemgetter
@ -13,13 +14,12 @@ from image_prediction.classifier.image_classifier import ImageClassifier
from image_prediction.estimator.adapter.adapters.keras import KerasEstimatorAdapter from image_prediction.estimator.adapter.adapters.keras import KerasEstimatorAdapter
from image_prediction.estimator.adapter.adapters.mock import EstimatorMock, EstimatorAdapterMock from image_prediction.estimator.adapter.adapters.mock import EstimatorMock, EstimatorAdapterMock
from image_prediction.estimator.preprocessor.preprocessors.basic import BasicPreprocessor from image_prediction.estimator.preprocessor.preprocessors.basic import BasicPreprocessor
from image_prediction.exceptions import UnknownEstimatorAdapter, UnknownImageExtractor, UnknownModelLoader from image_prediction.exceptions import UnknownEstimatorAdapter, UnknownImageExtractor, UnknownDatabaseType
from image_prediction.image_extractor.extractor import ImageMetadataPair from image_prediction.image_extractor.extractor import ImageMetadataPair
from image_prediction.image_extractor.extractors.mock import ImageExtractorMock from image_prediction.image_extractor.extractors.mock import ImageExtractorMock
from image_prediction.image_extractor.extractors.parsable import ParsablePDFImageExtractor from image_prediction.image_extractor.extractors.parsable import ParsablePDFImageExtractor
from image_prediction.model_loader.loaders.loaders import get_mlflow_loader from image_prediction.model_loader.database.connectors.mock import DatabaseConnectorMock
from image_prediction.model_loader.loaders.mlflow import MlflowLoader from image_prediction.model_loader.loader import ModelLoader
from image_prediction.model_loader.loaders.mock import ModelLoaderMock
@pytest.fixture @pytest.fixture
@ -207,29 +207,79 @@ def add_image_to_last_page(pdf: fpdf.fpdf.FPDF, image_metadata_pair):
pdf.image(temp_image.name, x=x, y=y, w=w, h=h, type="png") pdf.image(temp_image.name, x=x, y=y, w=w, h=h, type="png")
@pytest.fixture
def model_handle_mock(classes, classifier):
    """Provide a minimal model handle wrapping the ``classifier`` fixture.

    Creating the handle also sets ``classifier.classes_`` to the indices
    ``0 .. len(classes) - 1``.
    """

    class ModelHandleMock:
        def __init__(self, classes):
            class_indices = list(range(len(classes)))
            classifier.classes_ = np.array(class_indices)
            self.classes = classes
            self.model = classifier

    handle = ModelHandleMock(classes)
    return handle
#
#
# @pytest.fixture
# def prediction_model_handle_mock(model_handle_mock, classes):
# return PredictionModelHandle(model_handle_mock, classes)
@pytest.fixture
def model():
    """Provide a stub model whose prediction methods always return ``True``."""

    class Model:
        @staticmethod
        def predict(*args):
            """Ignore all arguments and report success."""
            return True

        @staticmethod
        def predict_proba(*args):
            """Ignore all arguments and report success."""
            return True

    stub = Model()
    return stub
@pytest.fixture
def model_database_record_identifier():
    """Provide a random identifier made of ten distinct ASCII letters."""
    letters = random.sample(string.ascii_letters, k=10)
    return "".join(letters)
@pytest.fixture
def model_database_record(model, classes):
    """Provide one database record holding the model and its classes."""
    record = dict(model=model, classes=classes)
    return record
@pytest.fixture
def model_database(model_database_record, model_database_record_identifier):
    """Provide a one-entry database mapping the identifier to its record."""
    database = {}
    database[model_database_record_identifier] = model_database_record
    return database
@pytest.fixture
def database_connector(database_type, model_database):
    """Provide the connector for ``database_type``; only ``"mock"`` is known.

    Raises:
        UnknownDatabaseType: for any other database type.
    """
    if database_type != "mock":
        raise UnknownDatabaseType(f"No connector for database type {database_type} was specified.")
    return DatabaseConnectorMock(model_database)
@pytest.fixture
def model_loader(database_connector):
    """Provide a ``ModelLoader`` reading from the ``database_connector`` fixture."""
    loader = ModelLoader(database_connector)
    return loader
# @pytest.fixture
# def model_loader(loader_type, monkeypatch, model_handle_mock, classes):
# if loader_type == "mock":
# loader = ModelLoaderMock()
# monkeypatch.setattr(loader, "model", model_handle_mock)
#
# # elif loader_type == "mlflow":
# # loader = get_mlflow_loader()
# # monkeypatch.setattr(loader, "_model_handle", model_handle_mock)
#
# else:
# raise UnknownModelLoader(f"No model loader for type {loader_type} was specified.")
#
# return loader

View File

@ -1,15 +1,19 @@
import numpy as np import numpy as np
import pytest import pytest
from image_prediction.model_loading import load_model_and_classes
# @pytest.mark.parametrize("loader_type", ["mock"])
# @pytest.mark.parametrize("estimator_type", ["mock"])
# @pytest.mark.parametrize("batch_size", [3])
# def test_load_model_and_classes(model_loader, model_handle_mock, classes):
# model_loaded, classes_loaded = model_loader.load_model_and_classes("an identifier")
# assert model_loaded == model_handle_mock
# assert np.all(classes_loaded == classes)
@pytest.mark.parametrize("database_type", ["mock"])
def test_load_model_and_classes(model_loader, model_database_record_identifier, model, classes):
    """The loader returns the model and classes stored under the record's identifier."""
    identifier = model_database_record_identifier

    loaded_model = model_loader.load_model(identifier)
    loaded_classes = model_loader.load_classes(identifier)

    assert loaded_model == model
    assert loaded_classes == classes