From 981d7816a0742aac2b2d3481e1cf4d2baffbda49 Mon Sep 17 00:00:00 2001 From: Matthias Bisping Date: Fri, 25 Mar 2022 17:58:34 +0100 Subject: [PATCH] refactoring: replaced estimator adapter with monkeypatch --- image_prediction/exceptions.py | 2 + test/unit_tests/conftest.py | 98 +++++++++++++++++++++++ test/unit_tests/service_estimator_test.py | 73 ----------------- 3 files changed, 100 insertions(+), 73 deletions(-) create mode 100644 image_prediction/exceptions.py create mode 100644 test/unit_tests/conftest.py diff --git a/image_prediction/exceptions.py b/image_prediction/exceptions.py new file mode 100644 index 0000000..290b600 --- /dev/null +++ b/image_prediction/exceptions.py @@ -0,0 +1,2 @@ +class UnknownEstimatorAdapter(ValueError): + pass \ No newline at end of file diff --git a/test/unit_tests/conftest.py b/test/unit_tests/conftest.py new file mode 100644 index 0000000..828f117 --- /dev/null +++ b/test/unit_tests/conftest.py @@ -0,0 +1,98 @@ +import numpy as np +import pytest +from PIL import Image + +from image_prediction.estimator.estimators.keras import KerasEstimator +from image_prediction.estimator.estimators.mock import EstimatorMock, DummyEstimator +from image_prediction.exceptions import UnknownEstimatorAdapter +from image_prediction.predictor.predictor import Predictor +from image_prediction.service_estimator.service_estimator import ServiceEstimator + + +@pytest.fixture +def predictor(service_estimator): + return Predictor(service_estimator) + + +@pytest.fixture +def service_estimator(estimator, classes): + service_estimator = ServiceEstimator(estimator, classes) + return service_estimator + + +@pytest.fixture +def estimator(estimator_type, keras_model, output_batch, monkeypatch): + if estimator_type == "mock": + estimator = EstimatorMock(DummyEstimator()) + elif estimator_type == "keras": + estimator = KerasEstimator(keras_model) + else: + raise UnknownEstimatorAdapter(f"No adapter for estimator type {estimator_type} was specified.") + + def mock_predict(batch): + _predict(batch) + return output_batch + + _predict = estimator.predict + monkeypatch.setattr(estimator, "predict", mock_predict) + + return estimator + + +@pytest.fixture +def keras_model(input_size): + import warnings + warnings.filterwarnings("ignore", category=DeprecationWarning) + + import os + os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' + + from tensorflow import keras + + inputs = keras.Input(shape=input_size) + dense = keras.layers.Dense(64, activation="relu") + outputs = keras.layers.Dense(10)(dense(inputs)) + model = keras.Model(inputs=inputs, outputs=outputs) + model.compile() + + return model + + +@pytest.fixture +def images(input_batch): + return list(map(array_to_image, input_batch)) + + +@pytest.fixture +def input_batch(batch_size, input_size): + return np.random.random_sample(size=(batch_size, *input_size)) + + +def array_to_image(array): + assert np.all(array <= 1) + assert np.all(array >= 0) + return Image.fromarray(np.uint8(array * 255)) + + +@pytest.fixture +def input_size(width=10, height=15): + return width, height + + +@pytest.fixture +def expected_predictions(output_batch, classes): + return map_labels(output_batch, classes) + + +@pytest.fixture +def output_batch(batch_size, classes): + return np.random.randint(low=0, high=len(classes), size=batch_size) + + +@pytest.fixture +def classes(): + return ["A", "B", "C"] + + +def map_labels(numeric_labels, classes): + return [classes[nl] for nl in numeric_labels] diff --git a/test/unit_tests/service_estimator_test.py b/test/unit_tests/service_estimator_test.py index 7503bd5..f561901 100644 --- a/test/unit_tests/service_estimator_test.py +++ b/test/unit_tests/service_estimator_test.py @@ -1,86 +1,13 @@ import logging -import numpy as np import pytest -from image_prediction.estimator.adapter.patch import EstimatorAdapterPatch -from image_prediction.estimator.estimators.keras import KerasEstimator -from image_prediction.estimator.estimators.mock import EstimatorMock, DummyEstimator -from image_prediction.service_estimator.service_estimator import ServiceEstimator from image_prediction.utils import get_logger logger = get_logger() logger.setLevel(logging.DEBUG) -@pytest.fixture(scope="session") -def input_size(): - return 10, 15 - - -@pytest.fixture(scope="session") -def keras_model(input_size): - import warnings - warnings.filterwarnings("ignore", category=DeprecationWarning) - - import os - os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' - - from tensorflow import keras - - inputs = keras.Input(shape=input_size) - dense = keras.layers.Dense(64, activation="relu") - outputs = keras.layers.Dense(10)(dense(inputs)) - model = keras.Model(inputs=inputs, outputs=outputs) - model.compile() - - return model - - -@pytest.fixture(scope="session") -def estimator(estimator_type, keras_model): - if estimator_type == "mock": - return EstimatorMock(DummyEstimator()) - if estimator_type == "keras": - return KerasEstimator(keras_model) - - -@pytest.fixture(scope="session") -def estimator_adapter(output_batch, estimator): - estimator_adapter = EstimatorAdapterPatch(estimator) - estimator_adapter.output_batch = output_batch - return estimator_adapter - - -@pytest.fixture(scope="session") -def input_batch(batch_size, classes, input_size): - return np.random.normal(size=(batch_size, *input_size)) - - -@pytest.fixture(scope="session") -def output_batch(batch_size, classes): - return np.random.randint(low=0, high=len(classes), size=batch_size) - - -@pytest.fixture(scope="session") -def expected_predictions(output_batch, classes): - return map_labels(output_batch, classes) - - -@pytest.fixture(scope="session") -def classes(): - return ["A", "B", "C"] - - -def map_labels(numeric_labels, classes): - return [classes[nl] for nl in numeric_labels] - - -@pytest.fixture(scope="session") -def service_estimator(estimator_adapter, classes): - return ServiceEstimator(estimator_adapter, classes) - - @pytest.mark.parametrize("estimator_type", ["mock", "keras"], scope="session") @pytest.mark.parametrize("batch_size", [0, 1, 2, 16, 32, 64], scope="session") def test_predict(service_estimator, input_batch, expected_predictions):