Matthias Bisping ea298dacfa renaming
2022-03-26 19:27:37 +01:00

108 lines
2.7 KiB
Python

import numpy as np
import pytest
from PIL import Image
from image_prediction.estimator.adapter.adapters.keras import KerasEstimatorAdapter
from image_prediction.estimator.adapter.adapters.mock import DummyEstimator, EstimatorAdapterMock
from image_prediction.estimator.estimator import Estimator
from image_prediction.exceptions import UnknownEstimatorAdapter
from image_prediction.predictor.predictor import Predictor
@pytest.fixture
def predictor(estimator):
    """Provide a `Predictor` built on top of the `estimator` fixture."""
    wrapped = Predictor(estimator)
    return wrapped
@pytest.fixture
def estimator(estimator_adapter, classes):
    """Provide an `Estimator` constructed from the adapter and class-list fixtures."""
    return Estimator(estimator_adapter, classes)
@pytest.fixture
def estimator_adapter(estimator_type, keras_model, output_batch, monkeypatch):
    """Provide an estimator adapter whose `predict` always yields `output_batch`.

    The adapter is chosen by `estimator_type` ("mock" or "keras"); that value
    comes from a fixture or parametrization not defined in this file section.
    The adapter's real `predict` is still invoked on every call, but its
    result is discarded and `output_batch` is returned instead.

    Raises:
        UnknownEstimatorAdapter: if `estimator_type` is neither "mock" nor
            "keras".
    """
    if estimator_type == "mock":
        adapter = EstimatorAdapterMock(DummyEstimator())
    elif estimator_type == "keras":
        adapter = KerasEstimatorAdapter(keras_model)
    else:
        raise UnknownEstimatorAdapter(f"No adapter for estimator type {estimator_type} was specified.")

    # Keep a handle on the real implementation before it gets patched away.
    original_predict = adapter.predict

    def predict_with_fixed_output(batch):
        # Run the real prediction path, then substitute the canned output.
        original_predict(batch)
        return output_batch

    monkeypatch.setattr(adapter, "predict", predict_with_fixed_output)
    return adapter
@pytest.fixture
def keras_model(input_size):
    """Provide a small compiled Keras model taking inputs of shape `input_size`.

    Architecture: Input -> Conv2D(3, 3) -> Dense(10, relu) -> Dense(10).
    TensorFlow is imported lazily inside the fixture.
    """
    import os
    import warnings

    warnings.filterwarnings("ignore", category=DeprecationWarning)

    # Must be set before tensorflow is imported; presumably this silences
    # TensorFlow's native log output -- confirm against the TF docs.
    os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'

    import tensorflow as tf

    tf.keras.backend.set_image_data_format('channels_last')

    layers = tf.keras.layers
    inputs = tf.keras.Input(shape=input_size)

    # Layers are created in the same order as they are chained.
    features = layers.Conv2D(3, 3)(inputs)
    hidden = layers.Dense(10, activation="relu")(features)
    outputs = layers.Dense(10)(hidden)

    network = tf.keras.Model(inputs=inputs, outputs=outputs)
    network.compile()
    return network
@pytest.fixture
def batch_size():
    """Provide the number of samples per batch."""
    samples_per_batch = 4
    return samples_per_batch
@pytest.fixture
def images(input_batch):
    """Provide the input batch as a list of images, one per array in the batch."""
    return [array_to_image(sample) for sample in input_batch]
@pytest.fixture
def input_batch(batch_size, input_size):
    """Provide a random float batch of shape `(batch_size, *input_size)`."""
    batch_shape = (batch_size, *input_size)
    return np.random.random_sample(size=batch_shape)
def array_to_image(array):
    """Convert an array with values in [0, 1] into an RGB PIL image.

    The values are scaled by 255 and cast to unsigned 8-bit integers before
    being handed to PIL. Values outside [0, 1] trip an assertion.
    """
    # Range checks: upper bound first, then lower bound.
    assert np.all(array <= 1)
    assert np.all(array >= 0)

    pixel_values = np.uint8(array * 255)
    return Image.fromarray(pixel_values, mode="RGB")
@pytest.fixture
def input_size(depth=3, width=10, height=15):
    """Provide the per-sample input shape as a `(width, height, depth)` tuple."""
    shape = (width, height, depth)
    return shape
@pytest.fixture
def expected_predictions(output_batch, classes):
    """Provide the class names matching the numeric labels in `output_batch`."""
    class_names = map_labels(output_batch, classes)
    return class_names
@pytest.fixture
def output_batch(batch_size, classes):
    """Provide `batch_size` random integer labels, each a valid index into `classes`."""
    n_classes = len(classes)
    return np.random.randint(low=0, high=n_classes, size=batch_size)
@pytest.fixture
def classes():
    """Provide the list of class names used for label mapping."""
    class_names = ["A", "B", "C"]
    return class_names
def map_labels(numeric_labels, classes):
    """Translate numeric labels into class names by indexing into `classes`.

    Returns a new list with one entry per element of `numeric_labels`, in the
    same order.
    """
    names = []
    for index in numeric_labels:
        names.append(classes[index])
    return names