removed unneeded adapter derivatives and made estimator adapter abstract base class to normal class
This commit is contained in:
parent
7340fb6dda
commit
3339ed2eab
@ -1,10 +1,9 @@
|
||||
import abc
|
||||
|
||||
|
||||
class EstimatorAdapter(abc.ABC):
|
||||
class EstimatorAdapter:
|
||||
def __init__(self, estimator):
|
||||
self.estimator = estimator
|
||||
|
||||
@abc.abstractmethod
|
||||
def predict(self, batch):
|
||||
pass
|
||||
return self.estimator.predict(batch)
|
||||
|
||||
def predict_proba(self, batch):
|
||||
return self.estimator.predict_proba(batch)
|
||||
|
||||
@ -1,11 +0,0 @@
|
||||
import numpy as np
|
||||
|
||||
from image_prediction.estimator.adapter.adapter import EstimatorAdapter
|
||||
|
||||
|
||||
class KerasEstimatorAdapter(EstimatorAdapter):
|
||||
def __init__(self, estimator):
|
||||
super().__init__(estimator)
|
||||
|
||||
def predict(self, batch: np.array):
|
||||
return self.estimator.predict(batch)
|
||||
@ -1,15 +0,0 @@
|
||||
from image_prediction.estimator.adapter.adapter import EstimatorAdapter
|
||||
|
||||
|
||||
class EstimatorMock:
|
||||
@staticmethod
|
||||
def predict(batch):
|
||||
return [None for _ in batch]
|
||||
|
||||
|
||||
class EstimatorAdapterMock(EstimatorAdapter):
|
||||
def __init__(self, estimator):
|
||||
super().__init__(estimator=estimator)
|
||||
|
||||
def predict(self, batch):
|
||||
return self.estimator.predict(batch)
|
||||
16
image_prediction/pipeline.py
Normal file
16
image_prediction/pipeline.py
Normal file
@ -0,0 +1,16 @@
|
||||
from image_prediction.classifier.classifier import Classifier
|
||||
from image_prediction.config import CONFIG
|
||||
from image_prediction.estimator.adapter.adapter import EstimatorAdapter
|
||||
from image_prediction.locations import MLRUNS_DIR
|
||||
from image_prediction.model_loader.loader import ModelLoader
|
||||
from image_prediction.model_loader.loaders.mlflow import MlflowConnector
|
||||
from image_prediction.redai_adapter.mlflow import MlflowModelReader
|
||||
|
||||
|
||||
class Pipeline:
|
||||
|
||||
def __init__(self):
|
||||
model_loader = ModelLoader(MlflowConnector(MlflowModelReader(MLRUNS_DIR)))
|
||||
model = model_loader.load_model(CONFIG.service.run_id)
|
||||
classes = model_loader.load_classes(CONFIG.service.run_id)
|
||||
classifier = Classifier(EstimatorAdapter(model), classes)
|
||||
@ -11,8 +11,7 @@ from PIL import Image
|
||||
|
||||
from image_prediction.classifier.classifier import Classifier
|
||||
from image_prediction.classifier.image_classifier import ImageClassifier
|
||||
from image_prediction.estimator.adapter.adapters.keras import KerasEstimatorAdapter
|
||||
from image_prediction.estimator.adapter.adapters.mock import EstimatorMock, EstimatorAdapterMock
|
||||
from image_prediction.estimator.adapter.adapter import EstimatorAdapter
|
||||
from image_prediction.estimator.preprocessor.preprocessors.basic import BasicPreprocessor
|
||||
from image_prediction.exceptions import UnknownEstimatorAdapter, UnknownImageExtractor, UnknownDatabaseType
|
||||
from image_prediction.image_extractor.extractor import ImageMetadataPair
|
||||
@ -48,12 +47,18 @@ def classifier(estimator_adapter, classes):
|
||||
return classifier
|
||||
|
||||
|
||||
class EstimatorMock:
|
||||
@staticmethod
|
||||
def predict(batch):
|
||||
return [None for _ in batch]
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def estimator_adapter(estimator_type, keras_model, output_batch_generator, monkeypatch):
|
||||
if estimator_type == "mock":
|
||||
estimator_adapter = EstimatorAdapterMock(EstimatorMock())
|
||||
estimator_adapter = EstimatorAdapter(EstimatorMock())
|
||||
elif estimator_type == "keras":
|
||||
estimator_adapter = KerasEstimatorAdapter(keras_model)
|
||||
estimator_adapter = EstimatorAdapter(keras_model)
|
||||
else:
|
||||
raise UnknownEstimatorAdapter(f"No adapter for estimator type {estimator_type} was specified.")
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user