"""Session-scoped fixtures and a parametrized test for ``ServiceEstimator.predict``."""

import logging

import numpy as np
import pytest

from image_prediction.estimator.adapter.patch import EstimatorAdapterPatch
from image_prediction.estimator.estimators.keras import KerasEstimator
from image_prediction.estimator.estimators.mock import EstimatorMock, DummyEstimator
from image_prediction.service_estimator.service_estimator import ServiceEstimator
from image_prediction.utils import get_logger

logger = get_logger()
logger.setLevel(logging.DEBUG)


@pytest.fixture(scope="session")
def input_size():
    """Return the per-sample input shape as a ``(10, 15)`` tuple."""
    return 10, 15


@pytest.fixture(scope="session")
def keras_model(input_size):
    """Build and compile a small two-layer dense Keras model for ``input_size``."""
    # Imports are kept local so the warning filter and the environment
    # variable are in place before TensorFlow is imported.
    import warnings

    warnings.filterwarnings("ignore", category=DeprecationWarning)

    import os

    # NOTE(review): presumably silences TensorFlow's native log output; it
    # has to be set before the tensorflow import below to take effect.
    os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'

    from tensorflow import keras

    inputs = keras.Input(shape=input_size)
    dense = keras.layers.Dense(64, activation="relu")
    outputs = keras.layers.Dense(10)(dense(inputs))

    model = keras.Model(inputs=inputs, outputs=outputs)
    model.compile()
    return model


@pytest.fixture(scope="session")
def estimator(estimator_type, keras_model):
    """Return the estimator selected by ``estimator_type``.

    Supported values are ``"mock"`` (an ``EstimatorMock`` wrapping a
    ``DummyEstimator``) and ``"keras"`` (a ``KerasEstimator`` wrapping the
    ``keras_model`` fixture).

    Raises:
        ValueError: If ``estimator_type`` is not one of the supported values.
    """
    if estimator_type == "mock":
        return EstimatorMock(DummyEstimator())
    if estimator_type == "keras":
        return KerasEstimator(keras_model)
    # Fail loudly here instead of handing ``None`` to the adapter, which
    # would only surface later as a confusing error inside the test.
    raise ValueError(f"Unknown estimator_type: {estimator_type!r}")


@pytest.fixture(scope="session")
def estimator_adapter(output_batch, estimator):
    """Wrap ``estimator`` in an ``EstimatorAdapterPatch`` carrying ``output_batch``."""
    adapter = EstimatorAdapterPatch(estimator)
    # NOTE(review): presumably the patch adapter returns this canned batch
    # instead of the estimator's real output -- confirm in EstimatorAdapterPatch.
    adapter.output_batch = output_batch
    return adapter


@pytest.fixture(scope="session")
def input_batch(batch_size, classes, input_size):
    """Return normally distributed random inputs of shape ``(batch_size, *input_size)``.

    ``classes`` is requested but not used by the body.
    """
    return np.random.normal(size=(batch_size, *input_size))


@pytest.fixture(scope="session")
def output_batch(batch_size, classes):
    """Return ``batch_size`` random integer labels in ``[0, len(classes))``."""
    return np.random.randint(low=0, high=len(classes), size=batch_size)


@pytest.fixture(scope="session")
def expected_predictions(output_batch, classes):
    """Return the class names corresponding to the labels in ``output_batch``."""
    return map_labels(output_batch, classes)


@pytest.fixture(scope="session")
def classes():
    """Return the list of class names used throughout the tests."""
    return ["A", "B", "C"]


def map_labels(numeric_labels, classes):
    """Translate each numeric label into its class name by indexing ``classes``."""
    return [classes[nl] for nl in numeric_labels]


@pytest.fixture(scope="session")
def service_estimator(estimator_adapter, classes):
    """Return a ``ServiceEstimator`` built from ``estimator_adapter`` and ``classes``."""
    return ServiceEstimator(estimator_adapter, classes)


@pytest.mark.parametrize("estimator_type", ["mock", "keras"], scope="session")
@pytest.mark.parametrize("batch_size", [0, 1, 2, 16, 32, 64], scope="session")
def test_predict(service_estimator, input_batch, expected_predictions):
    """Check that ``predict`` yields the class names mapped from the canned labels."""
    predictions = service_estimator.predict(input_batch)
    assert predictions == expected_predictions