removed unneeded adapter derivatives and made estimator adapter abstract base class to normal class

This commit is contained in:
Matthias Bisping 2022-03-29 20:44:26 +02:00
parent 7340fb6dda
commit 3339ed2eab
5 changed files with 30 additions and 36 deletions

View File

@ -1,10 +1,9 @@
import abc
class EstimatorAdapter(abc.ABC):
class EstimatorAdapter:
def __init__(self, estimator):
self.estimator = estimator
@abc.abstractmethod
def predict(self, batch):
pass
return self.estimator.predict(batch)
def predict_proba(self, batch):
return self.estimator.predict_proba(batch)

View File

@ -1,11 +0,0 @@
import numpy as np
from image_prediction.estimator.adapter.adapter import EstimatorAdapter
class KerasEstimatorAdapter(EstimatorAdapter):
def __init__(self, estimator):
super().__init__(estimator)
def predict(self, batch: np.array):
return self.estimator.predict(batch)

View File

@ -1,15 +0,0 @@
from image_prediction.estimator.adapter.adapter import EstimatorAdapter
class EstimatorMock:
@staticmethod
def predict(batch):
return [None for _ in batch]
class EstimatorAdapterMock(EstimatorAdapter):
def __init__(self, estimator):
super().__init__(estimator=estimator)
def predict(self, batch):
return self.estimator.predict(batch)

View File

@ -0,0 +1,16 @@
from image_prediction.classifier.classifier import Classifier
from image_prediction.config import CONFIG
from image_prediction.estimator.adapter.adapter import EstimatorAdapter
from image_prediction.locations import MLRUNS_DIR
from image_prediction.model_loader.loader import ModelLoader
from image_prediction.model_loader.loaders.mlflow import MlflowConnector
from image_prediction.redai_adapter.mlflow import MlflowModelReader
class Pipeline:
def __init__(self):
model_loader = ModelLoader(MlflowConnector(MlflowModelReader(MLRUNS_DIR)))
model = model_loader.load_model(CONFIG.service.run_id)
classes = model_loader.load_classes(CONFIG.service.run_id)
classifier = Classifier(EstimatorAdapter(model), classes)

View File

@ -11,8 +11,7 @@ from PIL import Image
from image_prediction.classifier.classifier import Classifier
from image_prediction.classifier.image_classifier import ImageClassifier
from image_prediction.estimator.adapter.adapters.keras import KerasEstimatorAdapter
from image_prediction.estimator.adapter.adapters.mock import EstimatorMock, EstimatorAdapterMock
from image_prediction.estimator.adapter.adapter import EstimatorAdapter
from image_prediction.estimator.preprocessor.preprocessors.basic import BasicPreprocessor
from image_prediction.exceptions import UnknownEstimatorAdapter, UnknownImageExtractor, UnknownDatabaseType
from image_prediction.image_extractor.extractor import ImageMetadataPair
@ -48,12 +47,18 @@ def classifier(estimator_adapter, classes):
return classifier
class EstimatorMock:
@staticmethod
def predict(batch):
return [None for _ in batch]
@pytest.fixture
def estimator_adapter(estimator_type, keras_model, output_batch_generator, monkeypatch):
if estimator_type == "mock":
estimator_adapter = EstimatorAdapterMock(EstimatorMock())
estimator_adapter = EstimatorAdapter(EstimatorMock())
elif estimator_type == "keras":
estimator_adapter = KerasEstimatorAdapter(keras_model)
estimator_adapter = EstimatorAdapter(keras_model)
else:
raise UnknownEstimatorAdapter(f"No adapter for estimator type {estimator_type} was specified.")