image-classification-service/test/unit_tests/service_estimator_test.py
2022-03-25 14:56:47 +01:00

89 lines
2.5 KiB
Python

import logging
import numpy as np
import pytest
from image_prediction.estimator.adapter.patch import EstimatorAdapterPatch
from image_prediction.estimator.estimators.keras import KerasEstimator
from image_prediction.estimator.estimators.mock import EstimatorMock, DummyEstimator
from image_prediction.service_estimator.service_estimator import ServiceEstimator
from image_prediction.utils import get_logger
# Logger obtained from the project's get_logger() helper, raised to DEBUG for
# this test module. NOTE(review): presumably this is the package-wide logger,
# in which case the code under test also logs at DEBUG here — confirm.
logger = get_logger()
logger.setLevel(logging.DEBUG)
@pytest.fixture(scope="session")
def input_size():
    """Per-sample input shape (no batch axis) shared by the model and batch fixtures."""
    shape = (10, 15)
    return shape
@pytest.fixture(scope="session")
def keras_model(input_size):
    """Build and compile a small, untrained Keras functional model.

    The input has shape ``input_size``. Two dense layers follow it: 64 ReLU
    units, then 10 linear units. ``compile()`` is called with no optimizer
    or loss.
    """
    import os
    import warnings

    # filterwarnings mutates the global filter list, so DeprecationWarnings
    # stay silenced for the rest of the session, not just inside this fixture.
    warnings.filterwarnings("ignore", category=DeprecationWarning)
    # Ask TensorFlow's C++ runtime to suppress its log output.
    # NOTE(review): this only takes effect if TensorFlow has not been imported
    # yet; the KerasEstimator import at module top may already load it — confirm.
    os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'

    from tensorflow import keras

    model_input = keras.Input(shape=input_size)
    hidden_layer = keras.layers.Dense(64, activation="relu")
    output_layer = keras.layers.Dense(10)
    model_output = output_layer(hidden_layer(model_input))

    model = keras.Model(inputs=model_input, outputs=model_output)
    model.compile()
    return model
@pytest.fixture(scope="session")
def estimator(estimator_type, keras_model):
    """Return the raw estimator selected by the ``estimator_type`` parameter.

    ``"mock"`` gives an ``EstimatorMock`` wrapping a ``DummyEstimator``.
    ``"keras"`` gives a ``KerasEstimator`` around the ``keras_model`` fixture.
    Because ``keras_model`` is a declared dependency, the Keras model is built
    even when the mock is selected.

    Raises:
        ValueError: if ``estimator_type`` is not one of the supported values.
            Raising here, rather than returning ``None``, surfaces the bad
            parametrization at its source instead of further down the
            fixture chain.
    """
    if estimator_type == "mock":
        return EstimatorMock(DummyEstimator())
    if estimator_type == "keras":
        return KerasEstimator(keras_model)
    raise ValueError(f"Unknown estimator_type: {estimator_type!r}")
@pytest.fixture(scope="session")
def estimator_adapter(output_batch, estimator):
    """Wrap ``estimator`` in an ``EstimatorAdapterPatch`` with a preset output batch.

    The adapter's ``output_batch`` attribute is set to the ``output_batch``
    fixture, the same array ``expected_predictions`` is derived from.
    NOTE(review): presumably the patch returns this array instead of running
    the estimator — confirm in EstimatorAdapterPatch.
    """
    patched = EstimatorAdapterPatch(estimator)
    patched.output_batch = output_batch
    return patched
@pytest.fixture(scope="session")
def input_batch(batch_size, classes, input_size):
    """Standard-normal input batch of shape ``(batch_size, *input_size)``.

    The data comes from a fixed-seed local ``Generator``, not the global
    ``np.random`` state. That makes it reproducible across runs, so a failing
    case can be replayed, and it does not disturb global RNG state used
    elsewhere. ``classes`` is requested but not used by this fixture.
    """
    rng = np.random.default_rng(seed=0)
    return rng.normal(size=(batch_size, *input_size))
@pytest.fixture(scope="session")
def output_batch(batch_size, classes):
    """Random integer class indices, one per sample, in ``[0, len(classes))``.

    A fixed-seed local ``Generator`` keeps the labels reproducible across runs
    and leaves the global ``np.random`` state untouched. The upper bound is
    exclusive, so every value is a valid index into ``classes``.
    """
    rng = np.random.default_rng(seed=1)
    return rng.integers(low=0, high=len(classes), size=batch_size)
@pytest.fixture(scope="session")
def expected_predictions(output_batch, classes):
    """Labels expected from the service: each index in ``output_batch`` looked up in ``classes``."""
    return map_labels(numeric_labels=output_batch, classes=classes)
@pytest.fixture(scope="session")
def classes():
    """Class names used by the service; a label's position is its numeric id."""
    return list("ABC")
def map_labels(numeric_labels, classes):
    """Translate numeric class indices into class names.

    Args:
        numeric_labels: iterable of integer indices into ``classes``.
        classes: indexable sequence of class names.

    Returns:
        A list with ``classes[i]`` for each ``i`` in ``numeric_labels``, in order.
    """
    names = []
    for index in numeric_labels:
        names.append(classes[index])
    return names
@pytest.fixture(scope="session")
def service_estimator(estimator_adapter, classes):
    """The ``ServiceEstimator`` under test, built from the patched adapter and ``classes``."""
    service = ServiceEstimator(estimator_adapter, classes)
    return service
@pytest.mark.parametrize("estimator_type", ["mock", "keras"], scope="session")
@pytest.mark.parametrize("batch_size", [0, 1, 2, 16, 32, 64], scope="session")
def test_predict(service_estimator, input_batch, expected_predictions):
    """``predict`` returns the labels obtained by indexing ``classes`` with ``output_batch``.

    Runs for every estimator backend and batch size, including an empty batch.
    """
    actual = service_estimator.predict(input_batch)
    assert actual == expected_predictions