From 7834a65ff53288632f5fb0b64075847a07f7b02a Mon Sep 17 00:00:00 2001 From: Matthias Bisping Date: Fri, 25 Mar 2022 14:46:04 +0100 Subject: [PATCH] added keras estimator wrapper --- .../estimator/estimators/keras.py | 11 ++++++ image_prediction/estimator/estimators/mock.py | 10 ++++- .../service_estimator/service_estimator.py | 8 ++++ test/unit_tests/service_estimator_test.py | 37 ++++++++++++++++--- 4 files changed, 58 insertions(+), 8 deletions(-) create mode 100644 image_prediction/estimator/estimators/keras.py diff --git a/image_prediction/estimator/estimators/keras.py b/image_prediction/estimator/estimators/keras.py new file mode 100644 index 0000000..8661c85 --- /dev/null +++ b/image_prediction/estimator/estimators/keras.py @@ -0,0 +1,11 @@ +import numpy as np + +from image_prediction.estimator.estimator import Estimator + + +class KerasEstimator(Estimator): + def __init__(self, estimator): + super().__init__(estimator) + + def predict(self, batch: np.array): + self.estimator.predict(batch) diff --git a/image_prediction/estimator/estimators/mock.py b/image_prediction/estimator/estimators/mock.py index 3076f48..c642489 100644 --- a/image_prediction/estimator/estimators/mock.py +++ b/image_prediction/estimator/estimators/mock.py @@ -1,10 +1,16 @@ from image_prediction.estimator.estimator import Estimator +class DummyEstimator: + @staticmethod + def predict(_): + return True + + class EstimatorMock(Estimator): - def __init__(self, estimator=lambda x: x): + def __init__(self, estimator): super().__init__(estimator=estimator) def predict(self, batch): - return self.estimator(batch) + return self.estimator.predict(batch) diff --git a/image_prediction/service_estimator/service_estimator.py b/image_prediction/service_estimator/service_estimator.py index f779b85..623ca01 100644 --- a/image_prediction/service_estimator/service_estimator.py +++ b/image_prediction/service_estimator/service_estimator.py @@ -1,6 +1,9 @@ from typing import Mapping, List from image_prediction.estimator.adapter.adapter import EstimatorAdapter +from image_prediction.utils import get_logger + +logger = get_logger() class ServiceEstimator: @@ -9,4 +12,9 @@ class ServiceEstimator: self.__classes = classes def predict(self, batch) -> List[str]: + + if batch.shape[0] == 0: + logger.warning("ServiceEstimator received empty batch") + return [] + return [self.__classes[numeric_label] for numeric_label in self.__estimator_adapter.predict(batch)] diff --git a/test/unit_tests/service_estimator_test.py b/test/unit_tests/service_estimator_test.py index ed309ed..aee2727 100644 --- a/test/unit_tests/service_estimator_test.py +++ b/test/unit_tests/service_estimator_test.py @@ -3,8 +3,9 @@ import logging import numpy as np import pytest -from image_prediction.estimator.estimators.mock import EstimatorMock from image_prediction.estimator.adapter.patch import EstimatorAdapterPatch +from image_prediction.estimator.estimators.keras import KerasEstimator +from image_prediction.estimator.estimators.mock import EstimatorMock, DummyEstimator from image_prediction.service_estimator.service_estimator import ServiceEstimator from image_prediction.utils import get_logger @@ -13,9 +14,33 @@ logger.setLevel(logging.DEBUG) @pytest.fixture(scope="session") -def estimator(estimator_type): +def input_size(): + return 10, 15 + + +@pytest.fixture(scope="session") +def keras_model(input_size): + + import os + os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' + + from tensorflow import keras + + inputs = keras.Input(shape=input_size) + dense = keras.layers.Dense(64, activation="relu") + outputs = keras.layers.Dense(10)(dense(inputs)) + model = keras.Model(inputs=inputs, outputs=outputs) + model.compile() + + return model + + +@pytest.fixture(scope="session") +def estimator(estimator_type, keras_model): if estimator_type == "mock": - return EstimatorMock() + return EstimatorMock(DummyEstimator()) + if estimator_type == "keras": + return KerasEstimator(keras_model) @pytest.fixture(scope="session") @@ -26,8 +51,8 @@ def estimator_adapter(output_batch, estimator): @pytest.fixture(scope="session") -def input_batch(batch_size, classes): - return np.random.normal(size=(batch_size, 10, 15)) +def input_batch(batch_size, classes, input_size): + return np.random.normal(size=(batch_size, *input_size)) @pytest.fixture(scope="session") @@ -54,7 +79,7 @@ def service_estimator(estimator_adapter, classes): return ServiceEstimator(estimator_adapter, classes) -@pytest.mark.parametrize("estimator_type", ["mock"], scope="session") +@pytest.mark.parametrize("estimator_type", ["mock", "keras"], scope="session") @pytest.mark.parametrize("batch_size", [0, 1, 2, 16, 32, 64], scope="session") def test_predict(service_estimator, input_batch, expected_predictions): predictions = service_estimator.predict(input_batch)