added keras estimator wrapper
This commit is contained in:
parent
8b7293be09
commit
7834a65ff5
11
image_prediction/estimator/estimators/keras.py
Normal file
11
image_prediction/estimator/estimators/keras.py
Normal file
@ -0,0 +1,11 @@
|
|||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from image_prediction.estimator.estimator import Estimator
|
||||||
|
|
||||||
|
|
||||||
|
class KerasEstimator(Estimator):
|
||||||
|
def __init__(self, estimator):
|
||||||
|
super().__init__(estimator)
|
||||||
|
|
||||||
|
def predict(self, batch: np.array):
|
||||||
|
self.estimator.predict(batch)
|
||||||
@ -1,10 +1,16 @@
|
|||||||
from image_prediction.estimator.estimator import Estimator
|
from image_prediction.estimator.estimator import Estimator
|
||||||
|
|
||||||
|
|
||||||
|
class DummyEstimator:
|
||||||
|
@staticmethod
|
||||||
|
def predict(_):
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
class EstimatorMock(Estimator):
|
class EstimatorMock(Estimator):
|
||||||
|
|
||||||
def __init__(self, estimator=lambda x: x):
|
def __init__(self, estimator):
|
||||||
super().__init__(estimator=estimator)
|
super().__init__(estimator=estimator)
|
||||||
|
|
||||||
def predict(self, batch):
|
def predict(self, batch):
|
||||||
return self.estimator(batch)
|
return self.estimator.predict(batch)
|
||||||
|
|||||||
@ -1,6 +1,9 @@
|
|||||||
from typing import Mapping, List
|
from typing import Mapping, List
|
||||||
|
|
||||||
from image_prediction.estimator.adapter.adapter import EstimatorAdapter
|
from image_prediction.estimator.adapter.adapter import EstimatorAdapter
|
||||||
|
from image_prediction.utils import get_logger
|
||||||
|
|
||||||
|
logger = get_logger()
|
||||||
|
|
||||||
|
|
||||||
class ServiceEstimator:
|
class ServiceEstimator:
|
||||||
@ -9,4 +12,9 @@ class ServiceEstimator:
|
|||||||
self.__classes = classes
|
self.__classes = classes
|
||||||
|
|
||||||
def predict(self, batch) -> List[str]:
|
def predict(self, batch) -> List[str]:
|
||||||
|
|
||||||
|
if batch.shape[0] == 0:
|
||||||
|
logger.warning("ServiceEstimator received empty batch")
|
||||||
|
return []
|
||||||
|
|
||||||
return [self.__classes[numeric_label] for numeric_label in self.__estimator_adapter.predict(batch)]
|
return [self.__classes[numeric_label] for numeric_label in self.__estimator_adapter.predict(batch)]
|
||||||
|
|||||||
@ -3,8 +3,9 @@ import logging
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from image_prediction.estimator.estimators.mock import EstimatorMock
|
|
||||||
from image_prediction.estimator.adapter.patch import EstimatorAdapterPatch
|
from image_prediction.estimator.adapter.patch import EstimatorAdapterPatch
|
||||||
|
from image_prediction.estimator.estimators.keras import KerasEstimator
|
||||||
|
from image_prediction.estimator.estimators.mock import EstimatorMock, DummyEstimator
|
||||||
from image_prediction.service_estimator.service_estimator import ServiceEstimator
|
from image_prediction.service_estimator.service_estimator import ServiceEstimator
|
||||||
from image_prediction.utils import get_logger
|
from image_prediction.utils import get_logger
|
||||||
|
|
||||||
@ -13,9 +14,33 @@ logger.setLevel(logging.DEBUG)
|
|||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="session")
|
@pytest.fixture(scope="session")
|
||||||
def estimator(estimator_type):
|
def input_size():
|
||||||
|
return 10, 15
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="session")
|
||||||
|
def keras_model(input_size):
|
||||||
|
|
||||||
|
import os
|
||||||
|
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
|
||||||
|
|
||||||
|
from tensorflow import keras
|
||||||
|
|
||||||
|
inputs = keras.Input(shape=input_size)
|
||||||
|
dense = keras.layers.Dense(64, activation="relu")
|
||||||
|
outputs = keras.layers.Dense(10)(dense(inputs))
|
||||||
|
model = keras.Model(inputs=inputs, outputs=outputs)
|
||||||
|
model.compile()
|
||||||
|
|
||||||
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="session")
|
||||||
|
def estimator(estimator_type, keras_model):
|
||||||
if estimator_type == "mock":
|
if estimator_type == "mock":
|
||||||
return EstimatorMock()
|
return EstimatorMock(DummyEstimator())
|
||||||
|
if estimator_type == "keras":
|
||||||
|
return KerasEstimator(keras_model)
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="session")
|
@pytest.fixture(scope="session")
|
||||||
@ -26,8 +51,8 @@ def estimator_adapter(output_batch, estimator):
|
|||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="session")
|
@pytest.fixture(scope="session")
|
||||||
def input_batch(batch_size, classes):
|
def input_batch(batch_size, classes, input_size):
|
||||||
return np.random.normal(size=(batch_size, 10, 15))
|
return np.random.normal(size=(batch_size, *input_size))
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="session")
|
@pytest.fixture(scope="session")
|
||||||
@ -54,7 +79,7 @@ def service_estimator(estimator_adapter, classes):
|
|||||||
return ServiceEstimator(estimator_adapter, classes)
|
return ServiceEstimator(estimator_adapter, classes)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("estimator_type", ["mock"], scope="session")
|
@pytest.mark.parametrize("estimator_type", ["mock", "keras"], scope="session")
|
||||||
@pytest.mark.parametrize("batch_size", [0, 1, 2, 16, 32, 64], scope="session")
|
@pytest.mark.parametrize("batch_size", [0, 1, 2, 16, 32, 64], scope="session")
|
||||||
def test_predict(service_estimator, input_batch, expected_predictions):
|
def test_predict(service_estimator, input_batch, expected_predictions):
|
||||||
predictions = service_estimator.predict(input_batch)
|
predictions = service_estimator.predict(input_batch)
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user