128 lines
3.7 KiB
Python
128 lines
3.7 KiB
Python
import pytest
|
|
|
|
from image_prediction.classifier.classifier import Classifier
|
|
from image_prediction.classifier.image_classifier import ImageClassifier
|
|
from image_prediction.estimator.adapter.adapter import EstimatorAdapter
|
|
from image_prediction.estimator.preprocessor.preprocessors.basic import BasicPreprocessor
|
|
from image_prediction.exceptions import UnknownImageExtractor, UnknownEstimatorAdapter
|
|
from image_prediction.image_extractor.extractors.mock import ImageExtractorMock
|
|
from image_prediction.image_extractor.extractors.parsable import ParsablePDFImageExtractor
|
|
from image_prediction.redai_adapter.model import PredictionModelHandle
|
|
|
|
|
|
@pytest.fixture
def estimator_mock():
    """Return a minimal estimator stand-in that yields one ``None`` per input item.

    The returned object offers ``predict`` and ``predict_proba`` (both static) and is
    itself callable, delegating calls to ``predict``.
    """

    def _one_none_per_item(items):
        # Iterate the input so generators are consumed exactly like the original mock did.
        return [None for _ in items]

    class EstimatorMock:
        @staticmethod
        def predict(batch):
            return _one_none_per_item(batch)

        @staticmethod
        def predict_proba(batch):
            return _one_none_per_item(batch)

        def __call__(self, batch):
            # Route through the attribute lookup so an instance-level override of
            # ``predict`` is honoured, as in the original implementation.
            return self.predict(batch)

    mock_instance = EstimatorMock()
    return mock_instance
|
|
|
|
|
|
@pytest.fixture
def image_extractor(extractor_type):
    """Return the image extractor selected by the ``extractor_type`` fixture.

    Mapping:
        "mock"         -> ``ImageExtractorMock()``
        "parsable_pdf" -> ``ParsablePDFImageExtractor()``
        "default"      -> ``None``

    Raises:
        UnknownImageExtractor: if ``extractor_type`` is none of the values above.
    """
    if extractor_type == "mock":
        return ImageExtractorMock()
    if extractor_type == "parsable_pdf":
        return ParsablePDFImageExtractor()
    if extractor_type == "default":
        return None
    raise UnknownImageExtractor(f"No image extractor for type {extractor_type} was specified.")
|
|
|
|
|
|
@pytest.fixture
def image_classifier(classifier, monkeypatch, batch_of_expected_string_labels):
    """Wrap the ``classifier`` fixture in an ``ImageClassifier`` using a ``BasicPreprocessor``.

    NOTE(review): ``monkeypatch`` and ``batch_of_expected_string_labels`` are not used in
    this body. They are presumably requested so pytest sets those fixtures up; confirm before
    removing them, because dropping a requested fixture can change test parametrization.
    """
    basic_preprocessor = BasicPreprocessor()
    wrapped = ImageClassifier(classifier, preprocessor=basic_preprocessor)
    return wrapped
|
|
|
|
|
|
@pytest.fixture
def classifier(estimator_adapter, label_mapper):
    """Return a ``Classifier`` built from the ``estimator_adapter`` and ``label_mapper`` fixtures."""
    return Classifier(estimator_adapter, label_mapper)
|
|
|
|
|
|
@pytest.fixture
def estimator_adapter(
    estimator_type, estimator_mock, keras_model, model_handle_mock, output_batch_generator, monkeypatch
):
    """Build an ``EstimatorAdapter`` whose ``predict`` returns externally defined outputs.

    The wrapped estimator is chosen by ``estimator_type``:
        "mock"  -> ``estimator_mock``
        "keras" -> ``keras_model``
        "redai" -> ``PredictionModelHandle(model_handle_mock)``

    The adapter's ``predict`` is monkeypatched for the duration of the test. The patched
    version still runs the real ``predict`` (to surface mechanical problems), then returns
    one item drawn from ``output_batch_generator`` per real prediction.

    NOTE(review): ``keras_model`` is requested (and therefore built) for every
    ``estimator_type``, not just "keras".

    Raises:
        UnknownEstimatorAdapter: if ``estimator_type`` is not one of the supported values.
    """
    if estimator_type == "mock":
        estimator_adapter = EstimatorAdapter(estimator_mock)
    elif estimator_type == "keras":
        estimator_adapter = EstimatorAdapter(keras_model)
    elif estimator_type == "redai":
        estimator_adapter = EstimatorAdapter(PredictionModelHandle(model_handle_mock))
    else:
        raise UnknownEstimatorAdapter(f"No adapter for estimator type {estimator_type} was specified.")

    # Capture the real bound method before it is replaced by the monkeypatch below.
    real_predict = estimator_adapter.predict

    def mock_predict(batch):
        # Run real predict function to test for mechanical issues, but return externally defined
        # predictions to test the callers of the estimator adapter against the expected predictions
        predictions = real_predict(batch)
        try:
            return [next(output_batch_generator) for _ in predictions]
        except StopIteration as err:
            # A bare StopIteration escaping a plain function can silently terminate an
            # enclosing iteration (e.g. map() over batches), hiding an exhausted generator.
            raise RuntimeError(
                "output_batch_generator was exhausted before all predictions could be mocked."
            ) from err

    monkeypatch.setattr(estimator_adapter, "predict", mock_predict)

    return estimator_adapter
|
|
|
|
|
|
@pytest.fixture
def keras_model(input_size):
    """Build a tiny compiled Keras functional model taking inputs of shape ``input_size``.

    Layers, in order: Conv2D(3 filters, kernel 3) -> Dense(10) -> Dense(10).
    Side effects: sets ``TF_CPP_MIN_LOG_LEVEL`` in ``os.environ`` (not restored) and
    sets Keras' global image data format to "channels_last".
    """
    import os

    # Set before importing tensorflow so it applies if this is the first tensorflow import.
    os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"

    import tensorflow as tf

    tf.keras.backend.set_image_data_format("channels_last")

    layers = tf.keras.layers
    inputs = tf.keras.Input(shape=input_size)

    # Layers are created in the same order as before, so auto-generated names match.
    conv = layers.Conv2D(3, 3)
    hidden = layers.Dense(10)
    head = layers.Dense(10)

    features = conv(inputs)
    outputs = head(hidden(features))

    built_model = tf.keras.Model(inputs=inputs, outputs=outputs)
    built_model.compile()
    return built_model
|
|
|
|
|
|
@pytest.fixture
def model():
    """Return a stub model whose ``predict`` and ``predict_proba`` always return ``True``.

    Both methods are static and accept any positional arguments, which they ignore.
    """

    class Model:
        @staticmethod
        def predict(*_ignored):
            return True

        @staticmethod
        def predict_proba(*_ignored):
            return True

    stub_model = Model()
    return stub_model
|
|
|
|
|
|
@pytest.fixture
def model_handle_mock(estimator_mock):
    """Return a model-handle stand-in wrapping the ``estimator_mock`` fixture.

    The handle stores ``estimator_mock`` as ``.model``. Its ``prep_images``, ``predict`` and
    ``predict_proba`` each return a list with one ``None`` per item of ``batch``.
    NOTE(review): presumably shaped to satisfy ``PredictionModelHandle`` — confirm.
    """

    def _one_none_per_item(items):
        return [None for _ in items]

    class ModelHandleMock:
        def __init__(self):
            self.model = estimator_mock

        def prep_images(self, batch):
            return _one_none_per_item(batch)

        def predict(self, batch):
            return _one_none_per_item(batch)

        def predict_proba(self, batch):
            return _one_none_per_item(batch)

    handle = ModelHandleMock()
    return handle
|