import os
import warnings

import numpy as np
import pytest
from PIL import Image

from image_prediction.estimator.estimators.keras import KerasEstimator
from image_prediction.estimator.estimators.mock import EstimatorMock, DummyEstimator
from image_prediction.exceptions import UnknownEstimatorAdapter
from image_prediction.predictor.predictor import Predictor
from image_prediction.service_estimator.service_estimator import ServiceEstimator


@pytest.fixture
def predictor(service_estimator):
    """Return a ``Predictor`` wrapping the ``service_estimator`` fixture."""
    return Predictor(service_estimator)


@pytest.fixture
def service_estimator(estimator, classes):
    """Return a ``ServiceEstimator`` built from the ``estimator`` and ``classes`` fixtures."""
    return ServiceEstimator(estimator, classes)


@pytest.fixture
def estimator(estimator_type, output_batch, monkeypatch, request):
    """Return an estimator adapter whose ``predict`` yields the ``output_batch`` fixture.

    ``estimator_type`` selects the adapter:

    * ``"mock"``  -> ``EstimatorMock(DummyEstimator())``
    * ``"keras"`` -> ``KerasEstimator`` around the ``keras_model`` fixture

    The adapter's real ``predict`` is still called on every invocation, so the
    input batch must be accepted by it. Its result is discarded, and
    ``output_batch`` is returned instead. Downstream predictions therefore match
    the ``expected_predictions`` fixture.

    Raises:
        UnknownEstimatorAdapter: if ``estimator_type`` is neither ``"mock"`` nor ``"keras"``.
    """
    if estimator_type == "mock":
        estimator = EstimatorMock(DummyEstimator())
    elif estimator_type == "keras":
        # Resolved lazily so mock-only tests never import TensorFlow or build a model.
        estimator = KerasEstimator(request.getfixturevalue("keras_model"))
    else:
        raise UnknownEstimatorAdapter(f"No adapter for estimator type {estimator_type} was specified.")

    # Keep a handle on the adapter's original predict before it is replaced below.
    real_predict = estimator.predict

    def mock_predict(batch):
        # Exercise the real prediction path, then hand back the canned labels.
        real_predict(batch)
        return output_batch

    monkeypatch.setattr(estimator, "predict", mock_predict)
    return estimator


@pytest.fixture
def keras_model(input_size):
    """Return a small compiled Keras model whose input shape is ``input_size``.

    Architecture: Input(shape=input_size) -> Dense(64, relu) -> Dense(10).
    The model is compiled without an optimizer or loss. It presumably only
    serves inference here (confirm if it is ever trained).
    """
    # Silence DeprecationWarnings raised by TensorFlow/Keras. pytest's per-test
    # warning capture should restore the filters once the test finishes.
    warnings.filterwarnings("ignore", category=DeprecationWarning)
    # Set before importing TensorFlow so its C++ log output is suppressed.
    os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
    from tensorflow import keras  # heavy import, deferred until a Keras model is needed

    inputs = keras.Input(shape=input_size)
    dense = keras.layers.Dense(64, activation="relu")
    outputs = keras.layers.Dense(10)(dense(inputs))
    model = keras.Model(inputs=inputs, outputs=outputs)
    model.compile()
    return model


@pytest.fixture
def images(input_batch):
    """Return ``input_batch`` converted to a list of PIL images, one per sample."""
    return list(map(array_to_image, input_batch))


@pytest.fixture
def input_batch(batch_size, input_size):
    """Return random floats in [0, 1) with shape ``(batch_size, *input_size)``."""
    return np.random.random_sample(size=(batch_size, *input_size))


def array_to_image(array):
    """Convert a float array with values in [0, 1] into an 8-bit PIL image.

    Values are scaled by 255 and truncated to ``uint8``. The asserts guard the
    value range, because out-of-range values would wrap around when cast.
    """
    assert np.all(array <= 1)
    assert np.all(array >= 0)
    return Image.fromarray(np.uint8(array * 255))


@pytest.fixture
def input_size(width=10, height=15):
    """Return the per-sample shape ``(10, 15)``.

    pytest does not treat defaulted arguments as fixtures, so ``width`` and
    ``height`` act as constants here.

    NOTE(review): the tuple is used as a numpy array shape (rows first). The
    resulting PIL images are therefore 15 pixels wide and 10 tall, despite the
    parameter names.
    """
    return width, height


@pytest.fixture
def expected_predictions(output_batch, classes):
    """Return the class names corresponding to the numeric labels in ``output_batch``."""
    return map_labels(output_batch, classes)


@pytest.fixture
def output_batch(batch_size, classes):
    """Return ``batch_size`` random integer labels in ``[0, len(classes))``."""
    return np.random.randint(low=0, high=len(classes), size=batch_size)


@pytest.fixture
def classes():
    """Return the class names used by the tests."""
    return ["A", "B", "C"]


def map_labels(numeric_labels, classes):
    """Map each numeric label to its class name by indexing into ``classes``."""
    return [classes[nl] for nl in numeric_labels]