Matthias Bisping 9e29d2e5f9 applied black
2022-04-14 19:15:19 +02:00

128 lines
3.7 KiB
Python

import pytest
from image_prediction.classifier.classifier import Classifier
from image_prediction.classifier.image_classifier import ImageClassifier
from image_prediction.estimator.adapter.adapter import EstimatorAdapter
from image_prediction.estimator.preprocessor.preprocessors.basic import BasicPreprocessor
from image_prediction.exceptions import UnknownImageExtractor, UnknownEstimatorAdapter
from image_prediction.image_extractor.extractors.mock import ImageExtractorMock
from image_prediction.image_extractor.extractors.parsable import ParsablePDFImageExtractor
from image_prediction.redai_adapter.model import PredictionModelHandle
@pytest.fixture
def estimator_mock():
    """Provide a stub estimator that answers every batch with placeholder predictions.

    The returned object exposes ``predict`` and ``predict_proba`` and is also
    callable (calling it delegates to ``predict``). Each entry point returns a
    list holding one ``None`` per item of the batch it receives.
    """

    def _placeholders(batch):
        # One ``None`` per batch element; only iteration is required of ``batch``.
        return [None for _ in batch]

    class EstimatorMock:
        @staticmethod
        def predict(batch):
            return _placeholders(batch)

        @staticmethod
        def predict_proba(batch):
            return _placeholders(batch)

        def __call__(self, batch):
            return self.predict(batch)

    return EstimatorMock()
@pytest.fixture
def image_extractor(extractor_type):
    """Build the image extractor named by ``extractor_type``.

    Returns:
        An ``ImageExtractorMock`` for ``"mock"``, a ``ParsablePDFImageExtractor``
        for ``"parsable_pdf"``, or ``None`` for ``"default"``.

    Raises:
        UnknownImageExtractor: If ``extractor_type`` is none of the values above.
    """
    if extractor_type == "mock":
        return ImageExtractorMock()

    if extractor_type == "parsable_pdf":
        return ParsablePDFImageExtractor()

    if extractor_type == "default":
        # "default" deliberately yields no extractor object.
        return None

    raise UnknownImageExtractor(f"No image extractor for type {extractor_type} was specified.")
@pytest.fixture
def image_classifier(classifier, monkeypatch, batch_of_expected_string_labels):
    """Wrap the ``classifier`` fixture in an ``ImageClassifier`` using a ``BasicPreprocessor``.

    ``monkeypatch`` and ``batch_of_expected_string_labels`` are not referenced
    in the body; they stay in the signature so pytest keeps resolving them
    whenever this fixture is requested.
    """
    preprocessor = BasicPreprocessor()
    return ImageClassifier(classifier, preprocessor=preprocessor)
@pytest.fixture
def classifier(estimator_adapter, label_mapper):
    """Create a ``Classifier`` from the ``estimator_adapter`` and ``label_mapper`` fixtures."""
    return Classifier(estimator_adapter, label_mapper)
@pytest.fixture
def estimator_adapter(
    estimator_type, estimator_mock, keras_model, model_handle_mock, output_batch_generator, monkeypatch
):
    """Build an ``EstimatorAdapter`` for ``estimator_type`` with a patched ``predict``.

    The wrapped estimator is ``estimator_mock`` for ``"mock"``, ``keras_model``
    for ``"keras"``, and ``PredictionModelHandle(model_handle_mock)`` for
    ``"redai"``.

    The adapter's ``predict`` attribute is replaced via ``monkeypatch``: the
    replacement still invokes the adapter's original ``predict`` on the batch,
    but returns one value drawn from ``output_batch_generator`` for every
    element the original call produced.

    Raises:
        UnknownEstimatorAdapter: If ``estimator_type`` is not one of the
            supported values.
    """
    # Pick the object to wrap first, then wrap it once.
    if estimator_type == "mock":
        estimator = estimator_mock
    elif estimator_type == "keras":
        estimator = keras_model
    elif estimator_type == "redai":
        estimator = PredictionModelHandle(model_handle_mock)
    else:
        raise UnknownEstimatorAdapter(f"No adapter for estimator type {estimator_type} was specified.")

    adapter = EstimatorAdapter(estimator)

    # Keep a handle on the genuine implementation before it gets patched out.
    real_predict = adapter.predict

    def mock_predict(batch):
        # Run real predict function to test for mechanical issues, but return externally defined
        # predictions to test the callers of the estimator adapter against the expected predictions
        return [next(output_batch_generator) for _ in real_predict(batch)]

    monkeypatch.setattr(adapter, "predict", mock_predict)
    return adapter
@pytest.fixture
def keras_model(input_size):
    """Build and compile a small Keras functional model for use in tests.

    The model takes inputs of shape ``input_size`` and stacks a ``Conv2D``
    layer (3 filters, kernel size 3) followed by two ``Dense`` layers of
    10 units each. ``compile()`` is called with no arguments.
    """
    import os

    # Must be set before TensorFlow is imported so that it applies to the import itself.
    os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"

    import tensorflow as tf

    tf.keras.backend.set_image_data_format("channels_last")

    model_input = tf.keras.Input(shape=input_size)

    # Layers are created and applied in the same order as before.
    conv_layer = tf.keras.layers.Conv2D(3, 3)
    hidden_layer = tf.keras.layers.Dense(10)
    features = hidden_layer(conv_layer(model_input))
    model_output = tf.keras.layers.Dense(10)(features)

    network = tf.keras.Model(inputs=model_input, outputs=model_output)
    network.compile()
    return network
@pytest.fixture
def model():
    """Provide a trivial model stub whose ``predict`` and ``predict_proba`` always return ``True``.

    Both methods accept any number of positional arguments and ignore them.
    """

    def _always_true(*_ignored):
        return True

    class Model:
        @staticmethod
        def predict(*args):
            return _always_true(*args)

        @staticmethod
        def predict_proba(*args):
            return _always_true(*args)

    return Model()
@pytest.fixture
def model_handle_mock(estimator_mock):
    """Provide a stub model handle backed by the ``estimator_mock`` fixture.

    The returned object stores ``estimator_mock`` as its ``model`` attribute.
    Its ``prep_images``, ``predict`` and ``predict_proba`` methods each return
    a list with one ``None`` per item of the given batch.
    """

    def _placeholders(batch):
        # One ``None`` per batch element; only iteration is required of ``batch``.
        return [None for _ in batch]

    class ModelHandleMock:
        def __init__(self):
            self.model = estimator_mock

        def prep_images(self, batch):
            return _placeholders(batch)

        def predict(self, batch):
            return _placeholders(batch)

        def predict_proba(self, batch):
            return _placeholders(batch)

    return ModelHandleMock()