refactoring: replaced estimator adapter with monkeypatch
This commit is contained in:
parent
2e36a9d46d
commit
981d7816a0
2
image_prediction/exceptions.py
Normal file
2
image_prediction/exceptions.py
Normal file
@ -0,0 +1,2 @@
|
|||||||
|
class UnknownEstimatorAdapter(ValueError):
|
||||||
|
pass
|
||||||
98
test/unit_tests/conftest.py
Normal file
98
test/unit_tests/conftest.py
Normal file
@ -0,0 +1,98 @@
|
|||||||
|
import numpy as np
|
||||||
|
import pytest
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
|
from image_prediction.estimator.estimators.keras import KerasEstimator
|
||||||
|
from image_prediction.estimator.estimators.mock import EstimatorMock, DummyEstimator
|
||||||
|
from image_prediction.exceptions import UnknownEstimatorAdapter
|
||||||
|
from image_prediction.predictor.predictor import Predictor
|
||||||
|
from image_prediction.service_estimator.service_estimator import ServiceEstimator
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def predictor(service_estimator):
|
||||||
|
return Predictor(service_estimator)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def service_estimator(estimator, classes):
|
||||||
|
service_estimator = ServiceEstimator(estimator, classes)
|
||||||
|
return service_estimator
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def estimator(estimator_type, keras_model, output_batch, monkeypatch):
|
||||||
|
if estimator_type == "mock":
|
||||||
|
estimator = EstimatorMock(DummyEstimator())
|
||||||
|
elif estimator_type == "keras":
|
||||||
|
estimator = KerasEstimator(keras_model)
|
||||||
|
else:
|
||||||
|
raise UnknownEstimatorAdapter(f"No adapter for estimator type {estimator_type} was specified.")
|
||||||
|
|
||||||
|
def mock_predict(batch):
|
||||||
|
_predict(batch)
|
||||||
|
return output_batch
|
||||||
|
|
||||||
|
_predict = estimator.predict
|
||||||
|
monkeypatch.setattr(estimator, "predict", mock_predict)
|
||||||
|
|
||||||
|
return estimator
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def keras_model(input_size):
|
||||||
|
import warnings
|
||||||
|
warnings.filterwarnings("ignore", category=DeprecationWarning)
|
||||||
|
|
||||||
|
import os
|
||||||
|
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
|
||||||
|
|
||||||
|
from tensorflow import keras
|
||||||
|
|
||||||
|
inputs = keras.Input(shape=input_size)
|
||||||
|
dense = keras.layers.Dense(64, activation="relu")
|
||||||
|
outputs = keras.layers.Dense(10)(dense(inputs))
|
||||||
|
model = keras.Model(inputs=inputs, outputs=outputs)
|
||||||
|
model.compile()
|
||||||
|
|
||||||
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def images(input_batch):
|
||||||
|
return list(map(array_to_image, input_batch))
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def input_batch(batch_size, input_size):
|
||||||
|
return np.random.random_sample(size=(batch_size, *input_size))
|
||||||
|
|
||||||
|
|
||||||
|
def array_to_image(array):
|
||||||
|
assert np.all(array <= 1)
|
||||||
|
assert np.all(array >= 0)
|
||||||
|
return Image.fromarray(np.uint8(array * 255))
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def input_size(width=10, height=15):
|
||||||
|
return width, height
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def expected_predictions(output_batch, classes):
|
||||||
|
return map_labels(output_batch, classes)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def output_batch(batch_size, classes):
|
||||||
|
return np.random.randint(low=0, high=len(classes), size=batch_size)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def classes():
|
||||||
|
return ["A", "B", "C"]
|
||||||
|
|
||||||
|
|
||||||
|
def map_labels(numeric_labels, classes):
|
||||||
|
return [classes[nl] for nl in numeric_labels]
|
||||||
@ -1,86 +1,13 @@
|
|||||||
import logging
|
import logging
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from image_prediction.estimator.adapter.patch import EstimatorAdapterPatch
|
|
||||||
from image_prediction.estimator.estimators.keras import KerasEstimator
|
|
||||||
from image_prediction.estimator.estimators.mock import EstimatorMock, DummyEstimator
|
|
||||||
from image_prediction.service_estimator.service_estimator import ServiceEstimator
|
|
||||||
from image_prediction.utils import get_logger
|
from image_prediction.utils import get_logger
|
||||||
|
|
||||||
logger = get_logger()
|
logger = get_logger()
|
||||||
logger.setLevel(logging.DEBUG)
|
logger.setLevel(logging.DEBUG)
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="session")
|
|
||||||
def input_size():
|
|
||||||
return 10, 15
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="session")
|
|
||||||
def keras_model(input_size):
|
|
||||||
import warnings
|
|
||||||
warnings.filterwarnings("ignore", category=DeprecationWarning)
|
|
||||||
|
|
||||||
import os
|
|
||||||
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
|
|
||||||
|
|
||||||
from tensorflow import keras
|
|
||||||
|
|
||||||
inputs = keras.Input(shape=input_size)
|
|
||||||
dense = keras.layers.Dense(64, activation="relu")
|
|
||||||
outputs = keras.layers.Dense(10)(dense(inputs))
|
|
||||||
model = keras.Model(inputs=inputs, outputs=outputs)
|
|
||||||
model.compile()
|
|
||||||
|
|
||||||
return model
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="session")
|
|
||||||
def estimator(estimator_type, keras_model):
|
|
||||||
if estimator_type == "mock":
|
|
||||||
return EstimatorMock(DummyEstimator())
|
|
||||||
if estimator_type == "keras":
|
|
||||||
return KerasEstimator(keras_model)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="session")
|
|
||||||
def estimator_adapter(output_batch, estimator):
|
|
||||||
estimator_adapter = EstimatorAdapterPatch(estimator)
|
|
||||||
estimator_adapter.output_batch = output_batch
|
|
||||||
return estimator_adapter
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="session")
|
|
||||||
def input_batch(batch_size, classes, input_size):
|
|
||||||
return np.random.normal(size=(batch_size, *input_size))
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="session")
|
|
||||||
def output_batch(batch_size, classes):
|
|
||||||
return np.random.randint(low=0, high=len(classes), size=batch_size)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="session")
|
|
||||||
def expected_predictions(output_batch, classes):
|
|
||||||
return map_labels(output_batch, classes)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="session")
|
|
||||||
def classes():
|
|
||||||
return ["A", "B", "C"]
|
|
||||||
|
|
||||||
|
|
||||||
def map_labels(numeric_labels, classes):
|
|
||||||
return [classes[nl] for nl in numeric_labels]
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="session")
|
|
||||||
def service_estimator(estimator_adapter, classes):
|
|
||||||
return ServiceEstimator(estimator_adapter, classes)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("estimator_type", ["mock", "keras"], scope="session")
|
@pytest.mark.parametrize("estimator_type", ["mock", "keras"], scope="session")
|
||||||
@pytest.mark.parametrize("batch_size", [0, 1, 2, 16, 32, 64], scope="session")
|
@pytest.mark.parametrize("batch_size", [0, 1, 2, 16, 32, 64], scope="session")
|
||||||
def test_predict(service_estimator, input_batch, expected_predictions):
|
def test_predict(service_estimator, input_batch, expected_predictions):
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user