"""Pytest fixtures for exercising the image classifier stack.

The fixtures build a random input batch, convert it to PIL images, and wire an
estimator adapter (chosen by the ``estimator_type`` fixture) whose ``predict`` is
patched to return pre-drawn labels, so callers can be checked against
``expected_predictions``.

NOTE(review): ``estimator_type`` and ``batch_size`` are requested but not defined
in this file; presumably they are supplied by the tests (e.g. parametrization).
"""

import random

import numpy as np
import pytest
from PIL import Image

from image_prediction.classifier.classifier import Classifier
from image_prediction.classifier.image_classifier import ImageClassifier
from image_prediction.estimator.adapter.adapters.keras import KerasEstimatorAdapter
from image_prediction.estimator.adapter.adapters.mock import DummyEstimator, EstimatorAdapterMock
from image_prediction.exceptions import UnknownEstimatorAdapter


@pytest.fixture
def predictor(classifier):
    """Return an ``ImageClassifier`` wrapping the ``classifier`` fixture."""
    return ImageClassifier(classifier)


@pytest.fixture
def classifier(estimator_adapter, classes):
    """Return a ``Classifier`` built from the estimator adapter and the class labels."""
    return Classifier(estimator_adapter, classes)


@pytest.fixture
def estimator_adapter(estimator_type, output_batch_generator, monkeypatch, request):
    """Build the adapter selected by ``estimator_type`` and patch its ``predict``.

    ``"mock"`` wraps a ``DummyEstimator``. ``"keras"`` wraps the ``keras_model``
    fixture, which is requested lazily so that mock-only tests never import or
    build a TensorFlow model.

    The patched ``predict`` still runs the adapter's real ``predict``, but replaces
    each item it yields with the next pre-drawn label from ``output_batch_generator``.

    Raises:
        UnknownEstimatorAdapter: if ``estimator_type`` is neither ``"mock"`` nor ``"keras"``.
    """
    if estimator_type == "mock":
        adapter = EstimatorAdapterMock(DummyEstimator())
    elif estimator_type == "keras":
        # Only pay for TensorFlow when the keras adapter is actually under test.
        adapter = KerasEstimatorAdapter(request.getfixturevalue("keras_model"))
    else:
        raise UnknownEstimatorAdapter(f"No adapter for estimator type {estimator_type} was specified.")

    # Capture the unpatched method before monkeypatching replaces it.
    real_predict = adapter.predict

    def mock_predict(batch):
        # Run the real predict function to test for mechanical issues, but return the
        # externally defined predictions so callers of the adapter can be checked
        # against the expected predictions. One label is consumed per yielded item.
        return [next(output_batch_generator) for _ in real_predict(batch)]

    monkeypatch.setattr(adapter, "predict", mock_predict)
    return adapter


@pytest.fixture
def keras_model(input_size):
    """Return a small compiled Keras model accepting inputs of shape ``input_size``.

    In this file it is only consumed by ``estimator_adapter``, whose patched
    ``predict`` discards the model's outputs in favour of pre-drawn labels.
    """
    import os
    import warnings

    # TensorFlow reads this when its native library loads, so set it before importing.
    os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"

    # Keep the DeprecationWarning suppression local to TF import/model construction
    # instead of installing a process-wide warnings filter.
    with warnings.catch_warnings():
        warnings.simplefilter("ignore", category=DeprecationWarning)
        import tensorflow as tf

        tf.keras.backend.set_image_data_format("channels_last")
        inputs = tf.keras.Input(shape=input_size)
        conv = tf.keras.layers.Conv2D(3, 3)
        dense = tf.keras.layers.Dense(10)
        outputs = tf.keras.layers.Dense(10)(dense(conv(inputs)))
        model = tf.keras.Model(inputs=inputs, outputs=outputs)
        model.compile()
    return model


@pytest.fixture
def images(input_batch):
    """Return ``input_batch`` converted to a list of PIL images, one per array."""
    return [array_to_image(array) for array in input_batch]


@pytest.fixture
def input_batch(batch_size, input_size):
    """Return ``batch_size`` random float arrays in [0, 1), each shaped ``input_size``."""
    return np.random.random_sample(size=(batch_size, *input_size))


def array_to_image(array):
    """Convert a ``(height, width, 3)`` float array with values in [0, 1] to an RGB image.

    Values are scaled to [0, 255] and truncated to uint8. Pillow infers ``"RGB"``
    from a uint8 array with a trailing 3-channel axis, so no explicit ``mode`` is
    passed (that parameter of ``Image.fromarray`` is deprecated in recent Pillow).
    """
    assert np.all(array <= 1)
    assert np.all(array >= 0)
    # Make the RGB requirement explicit now that mode is inferred from the shape.
    assert array.ndim == 3 and array.shape[-1] == 3
    return Image.fromarray(np.uint8(array * 255))


@pytest.fixture
def input_size(depth=3, width=15, height=10):
    """Return the per-image array shape ``(height, width, depth)``, i.e. ``(10, 15, 3)``.

    Axis 0 of an image array is its rows (height), matching how ``Image.fromarray``
    interprets it. Arguments with defaults are not resolved as fixtures by pytest,
    so they act as constants here.
    """
    return height, width, depth


@pytest.fixture
def expected_predictions(output_batch, classes):
    """Return the class labels corresponding to the pre-drawn numeric ``output_batch``."""
    return map_labels(output_batch, classes)


@pytest.fixture
def output_batch(input_batch, classes):
    """Return one random class index per item in ``input_batch``."""
    return random.choices(range(len(classes)), k=len(input_batch))


@pytest.fixture
def output_batch_generator(output_batch):
    """Return an iterator over ``output_batch``, consumed by the patched ``predict``."""
    return iter(output_batch)


@pytest.fixture
def classes():
    """Return the class labels used throughout the tests."""
    return ["A", "B", "C"]


def map_labels(numeric_labels, classes):
    """Map each numeric label to its entry in ``classes``."""
    return [classes[nl] for nl in numeric_labels]