361 lines
10 KiB
Python

import random
import string
import tempfile
from itertools import starmap
from operator import itemgetter
import fpdf
import numpy as np
import pytest
from PIL import Image
from image_prediction.classifier.classifier import Classifier
from image_prediction.classifier.image_classifier import ImageClassifier
from image_prediction.estimator.adapter.adapter import EstimatorAdapter
from image_prediction.estimator.preprocessor.preprocessors.basic import BasicPreprocessor
from image_prediction.exceptions import UnknownEstimatorAdapter, UnknownImageExtractor, UnknownDatabaseType, \
UnknownLabelFormat
from image_prediction.image_extractor.extractor import ImageMetadataPair
from image_prediction.image_extractor.extractors.mock import ImageExtractorMock
from image_prediction.image_extractor.extractors.parsable import ParsablePDFImageExtractor
from image_prediction.info import Info
from image_prediction.label_mapper.mappers.numeric import IndexMapper
from image_prediction.label_mapper.mappers.probability import ProbabilityMapper
from image_prediction.model_loader.database.connectors.mock import DatabaseConnectorMock
from image_prediction.model_loader.loader import ModelLoader
from image_prediction.model_loader.loaders.mlflow import MlflowConnector
from image_prediction.redai_adapter.mlflow import MlflowModelReader
@pytest.fixture
def image_extractor(extractor_type):
    """Provide the image extractor selected by ``extractor_type``.

    ``"mock"`` yields an ``ImageExtractorMock``, ``"parsable_pdf"`` a
    ``ParsablePDFImageExtractor`` and ``"default"`` yields ``None``.

    Raises:
        UnknownImageExtractor: for any other ``extractor_type``.
    """
    if extractor_type == "mock":
        return ImageExtractorMock()
    if extractor_type == "parsable_pdf":
        return ParsablePDFImageExtractor()
    if extractor_type == "default":
        return None
    raise UnknownImageExtractor(f"No image extractor for type {extractor_type} was specified.")
@pytest.fixture
def image_classifier(classifier, monkeypatch, batch_of_expected_string_labels):
    """Wrap ``classifier`` in an ``ImageClassifier`` that uses a ``BasicPreprocessor``.

    NOTE(review): ``monkeypatch`` and ``batch_of_expected_string_labels`` are
    requested but not used in this body. Presumably they are kept so pytest sets
    them up for the test. Confirm before removing them.
    """
    preprocessor = BasicPreprocessor()
    return ImageClassifier(classifier, preprocessor=preprocessor)
@pytest.fixture
def classifier(estimator_adapter, label_mapper):
    """Combine the estimator adapter and label mapper into a ``Classifier``."""
    return Classifier(estimator_adapter, label_mapper)
class EstimatorMock:
    """Stand-in estimator that predicts ``None`` for every item of a batch."""

    @staticmethod
    def predict(batch):
        """Return a list with one ``None`` per element yielded by iterating ``batch``."""
        return list(map(lambda _item: None, batch))

    def __call__(self, batch):
        """Delegate to :meth:`predict` so the mock can be called like a model."""
        return self.predict(batch)
@pytest.fixture
def label_mapper(label_format, classes):
    """Provide the label mapper for ``label_format`` over ``classes``.

    ``"index"`` gives an ``IndexMapper`` and ``"probability"`` a ``ProbabilityMapper``.

    Raises:
        UnknownLabelFormat: for any other ``label_format``.
    """
    if label_format == "index":
        return IndexMapper(classes)
    if label_format == "probability":
        return ProbabilityMapper(classes)
    raise UnknownLabelFormat(f"No label mapper for label format {label_format} was specified.")
@pytest.fixture(params=["index"])
def label_format(request):
    """Parametrize tests over label formats.

    Only ``"index"`` is listed here, although other fixtures in this module
    also handle ``"probability"``.
    """
    selected_format = request.param
    return selected_format
@pytest.fixture
def expected_predictions_mapped(
    label_format, batch_of_expected_string_labels, batch_of_expected_label_to_probability_mappings
):
    """Return the label-mapped predictions expected for ``label_format``.

    ``"index"`` expects class-name strings, and ``"probability"`` expects
    ``{"label": ..., "probabilities": ...}`` mappings.

    Raises:
        UnknownLabelFormat: for any other ``label_format``.
    """
    if label_format == "index":
        return batch_of_expected_string_labels
    if label_format == "probability":
        return batch_of_expected_label_to_probability_mappings
    raise UnknownLabelFormat(f"No label mapper for label format {label_format} was specified.")
@pytest.fixture
def expected_predictions(
    label_format, batch_of_expected_numeric_labels, batch_of_expected_probability_arrays
):
    """Return the raw (unmapped) predictions expected for ``label_format``.

    ``"index"`` expects numeric class indices, and ``"probability"`` expects one
    probability array per batch item.

    Raises:
        UnknownLabelFormat: for any other ``label_format``.
    """
    if label_format == "index":
        return batch_of_expected_numeric_labels
    if label_format == "probability":
        return batch_of_expected_probability_arrays
    raise UnknownLabelFormat(f"No label mapper for label format {label_format} was specified.")
@pytest.fixture
def estimator_adapter(estimator_type, keras_model, output_batch_generator, monkeypatch):
    """Provide an ``EstimatorAdapter`` whose ``predict`` returns the expected outputs.

    ``"mock"`` wraps an ``EstimatorMock`` and ``"keras"`` wraps ``keras_model``.
    The adapter's ``predict`` is monkeypatched. The patched version still runs the
    real ``predict``, but for every element that call returns it hands out the
    next item of ``output_batch_generator`` instead.

    Raises:
        UnknownEstimatorAdapter: for any other ``estimator_type``.
    """
    if estimator_type == "mock":
        adapter = EstimatorAdapter(EstimatorMock())
    elif estimator_type == "keras":
        adapter = EstimatorAdapter(keras_model)
    else:
        raise UnknownEstimatorAdapter(f"No adapter for estimator type {estimator_type} was specified.")

    real_predict = adapter.predict

    def predict_with_expected_outputs(batch):
        # The real predict still runs so mechanical issues surface. The values come
        # from the externally defined expectations, so callers of the adapter can be
        # checked against known predictions.
        return [next(output_batch_generator) for _ in real_predict(batch)]

    monkeypatch.setattr(adapter, "predict", predict_with_expected_outputs)
    return adapter
@pytest.fixture
def keras_model(input_size):
    """Build and compile a small, untrained Keras model with input shape ``input_size``.

    Architecture: Conv2D(3 filters, kernel 3) -> Dense(10) -> Dense(10).
    """
    import os

    # Silence TensorFlow's native logging; the variable is set before the import below.
    os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
    import tensorflow as tf

    keras = tf.keras
    keras.backend.set_image_data_format("channels_last")

    inputs = keras.Input(shape=input_size)
    # All layers are constructed first and then wired, matching the original order.
    conv_layer = keras.layers.Conv2D(3, 3)
    hidden_layer = keras.layers.Dense(10)
    head_layer = keras.layers.Dense(10)
    outputs = head_layer(hidden_layer(conv_layer(inputs)))

    network = keras.Model(inputs=inputs, outputs=outputs)
    network.compile()
    return network
@pytest.fixture
def images(input_batch):
    """Convert every array in ``input_batch`` to an RGB PIL image."""
    return [array_to_image(sample) for sample in input_batch]
@pytest.fixture
def input_batch(batch_size, input_size):
    """Draw a batch of uniform random arrays in [0, 1) shaped ``(batch_size, *input_size)``."""
    shape = (batch_size,) + tuple(input_size)
    return np.random.random_sample(size=shape)
@pytest.fixture(params=[0, 1, 2, 16, 32])
def batch_size(request):
    """Parametrize tests over batch sizes, including an empty batch (0)."""
    size = request.param
    return size
@pytest.fixture(
    params=[
        {"width": 10, "height": 15, "depth": 3},
        {"width": 150, "height": 100, "depth": 3},
    ]
)
def input_size(request):
    """Return the parametrized input dimensions as a ``(width, height, depth)`` tuple."""
    spec = request.param
    return spec["width"], spec["height"], spec["depth"]
def array_to_image(array):
    """Convert a float array with values in [0, 1] into an RGB PIL image.

    Values are scaled by 255 and cast to ``uint8``, which truncates the fractional
    part. ``Image.fromarray`` reads the first axis as image rows. The array is
    presumably 3-D with 3 channels last, as ``mode="RGB"`` requires; TODO confirm.

    Raises:
        AssertionError: if any value lies outside [0, 1].
    """
    # Reject out-of-range values before the uint8 cast silently wraps them.
    assert np.all(array <= 1)
    assert np.all(array >= 0)
    return Image.fromarray(np.uint8(array * 255), mode="RGB")
@pytest.fixture
def batch_of_expected_string_labels(batch_of_expected_numeric_labels, classes):
    """Translate the expected numeric labels into class names."""
    string_labels = map_labels(batch_of_expected_numeric_labels, classes)
    return string_labels
@pytest.fixture
def batch_of_expected_numeric_labels(batch_size, classes):
    """Draw ``batch_size`` random class indices, with replacement, from ``range(len(classes))``."""
    class_indices = range(len(classes))
    return random.choices(class_indices, k=batch_size)
@pytest.fixture
def batch_of_expected_label_to_probability_mappings(batch_of_expected_probability_arrays, classes):
    """Turn each probability array into ``{"label": best, "probabilities": {class: p}}``.

    The ``probabilities`` dict is ordered from most to least likely class, and
    ``label`` is its first key.
    """

    def to_mapping(probability_array):
        ranked = sorted(zip(classes, probability_array), key=itemgetter(1), reverse=True)
        label_to_probability = dict(ranked)
        top_label = list(label_to_probability)[0]
        return {"label": top_label, "probabilities": label_to_probability}

    return [to_mapping(probabilities) for probabilities in batch_of_expected_probability_arrays]
@pytest.fixture
def batch_of_expected_probability_arrays(batch_size, classes):
    """Draw one uniform random array of length ``len(classes)`` per batch item."""
    arrays = []
    for _ in range(batch_size):
        arrays.append(np.random.uniform(size=len(classes)))
    return arrays
@pytest.fixture
def output_batch_generator(expected_predictions):
    """Return an iterator over ``expected_predictions``; each ``next()`` yields one item."""
    predictions_iterator = iter(expected_predictions)
    return predictions_iterator
@pytest.fixture
def classes():
    """Return the class names used throughout the tests (a fresh list per request)."""
    class_names = ["A", "B", "C"]
    return class_names
def map_labels(numeric_labels, classes):
    """Translate each class index in ``numeric_labels`` into ``classes[index]``.

    Returns:
        A new list of class names, in input order.
    """
    return [classes[index] for index in numeric_labels]
@pytest.fixture
def metadata(images, info_label_map):
    """Generate random placement metadata for every image on a 595 x 842 page.

    595 x 842 is presumably A4 in points, since the ``pdf`` fixture uses ``unit="pt"``.
    Each image gets a random position that keeps it fully on the page. Each image
    also gets a page index that advances by a random 0-3 pages per image and is
    capped at ``len(images) - 1``.
    """
    page_width, page_height = 595, 842
    running_page_idx = 0

    def next_page_idx():
        # A step of 0 lets consecutive images share a page; the index never exceeds
        # the last index allowed by the image count.
        nonlocal running_page_idx
        running_page_idx += random.randint(0, 3)
        return min(running_page_idx, len(images) - 1)

    def place(image):
        width, height = image.size
        # Random draw order (x, then y, then page) matches the original fixture.
        x_start = random.randint(0, page_width - width)
        y_start = random.randint(0, page_height - height)
        return {
            info_label_map.PAGE_WIDTH: page_width,
            info_label_map.PAGE_HEIGHT: page_height,
            info_label_map.PAGE_IDX: next_page_idx(),
            info_label_map.WIDTH: width,
            info_label_map.HEIGHT: height,
            info_label_map.X1: x_start,
            info_label_map.X2: x_start + width,
            info_label_map.Y1: y_start,
            info_label_map.Y2: y_start + height,
        }

    return [place(image) for image in images]
@pytest.fixture
def info_label_map():
    """Expose ``Info``, whose members serve as the metadata dict keys."""
    label_map = Info
    return label_map
@pytest.fixture
def metadata_formatted(metadata):
    """Re-key every metadata dict by each key's ``.value``; values are unchanged."""
    return [{key.value: value for key, value in entry.items()} for entry in metadata]
@pytest.fixture
def image_metadata_pairs(images, metadata):
    """Zip images with their metadata into ``ImageMetadataPair`` objects."""
    return [ImageMetadataPair(image, meta) for image, meta in zip(images, metadata)]
@pytest.fixture
def pdf(image_metadata_pairs):
    """Render all image/metadata pairs into a PDF (units: points) and return its bytes."""
    document = fpdf.FPDF(unit="pt")
    for image_metadata_pair in image_metadata_pairs:
        add_image(document, image_metadata_pair)
    serialized = pdf_stream(document)
    return serialized
def add_image(pdf, image_metadata_pair):
    """Append pages until the pair's page index exists, then draw the image on the last page.

    If the document already has more pages than required, no page is added and
    the image still goes on the last page.
    """
    target_page_idx = image_metadata_pair.metadata[Info.PAGE_IDX]
    while fewer_pages_then_required(target_page_idx, pdf):
        pdf.add_page()
    add_image_to_last_page(pdf, image_metadata_pair)
def fewer_pages_then_required(page_idx, pdf):
    """Return True while ``pdf`` has no page with 0-based index ``page_idx``.

    ``pdf.page`` is compared as a page count. This assumes fpdf's 1-based
    current-page number equals the number of pages appended so far (TODO confirm).
    """
    return pdf.page - 1 < page_idx
def pdf_stream(pdf: "fpdf.fpdf.FPDF"):
    """Serialize ``pdf`` and return the document as ``bytes``.

    PyFPDF 1.x returns the document from ``output(dest="S")`` as a latin-1 ``str``,
    which is encoded here. fpdf2 returns a ``bytearray`` instead, which has no
    ``.encode``, so binary output is passed through as ``bytes``.
    """
    # The string annotation keeps the signature importable without evaluating fpdf.
    document = pdf.output(dest="S")
    if isinstance(document, str):
        return document.encode("latin1")
    return bytes(document)
def add_image_to_last_page(pdf: "fpdf.fpdf.FPDF", image_metadata_pair):
    """Draw the pair's image on the current (last) page at the position in its metadata.

    The image is placed at ``(Info.X1, Info.Y1)`` with size ``(Info.WIDTH, Info.HEIGHT)``.
    ``pdf.image`` is given a file path, so the PIL image is first saved as a
    temporary PNG. The PNG goes in a temporary directory rather than an open
    ``NamedTemporaryFile``: on Windows a still-open ``NamedTemporaryFile`` cannot
    be reopened by name, so both ``image.save`` and ``pdf.image`` would fail there.
    """
    import os

    image, metadata = image_metadata_pair
    x, y, w, h = itemgetter(Info.X1, Info.Y1, Info.WIDTH, Info.HEIGHT)(metadata)
    with tempfile.TemporaryDirectory() as temp_dir:
        image_path = os.path.join(temp_dir, "image.png")
        image.save(image_path)
        pdf.image(image_path, x=x, y=y, w=w, h=h, type="png")
@pytest.fixture
def model():
    """Provide a stand-in model object.

    Both ``predict`` and ``predict_proba`` accept any positional arguments and
    always return ``True``.
    """

    class Model:
        @staticmethod
        def predict(*args):
            return True

        @staticmethod
        def predict_proba(*args):
            return True

    stub = Model()
    return stub
@pytest.fixture
def model_database_record_identifier():
    """Return a random 10-letter ASCII identifier with no repeated characters."""
    letters = random.sample(string.ascii_letters, k=10)
    return "".join(letters)
@pytest.fixture
def model_database_record(model, classes):
    """Bundle the stand-in model with its class names as a database record dict."""
    return dict(model=model, classes=classes)
@pytest.fixture
def model_database(model_database_record, model_database_record_identifier):
    """Return an in-memory dict database holding a single record under its identifier."""
    database = {model_database_record_identifier: model_database_record}
    return database
@pytest.fixture
def database_connector(database_type, model_database, mlflow_reader):
    """Provide the model-database connector for ``database_type``.

    ``"mock"`` wraps the ``model_database`` dict in a ``DatabaseConnectorMock``, and
    ``"mlflow"`` wraps ``mlflow_reader`` in an ``MlflowConnector``. Both dependencies
    are requested, so pytest builds them whatever the type.

    Raises:
        UnknownDatabaseType: for any other ``database_type``.
    """
    if database_type == "mock":
        return DatabaseConnectorMock(model_database)
    if database_type == "mlflow":
        return MlflowConnector(mlflow_reader)
    raise UnknownDatabaseType(f"No connector for database type {database_type} was specified.")
@pytest.fixture
def model_loader(database_connector):
    """Provide a ``ModelLoader`` backed by ``database_connector``."""
    loader = ModelLoader(database_connector)
    return loader
@pytest.fixture
def mlflow_run_id():
    """Return the MLflow run id configured at ``CONFIG.service.run_id``."""
    from image_prediction.config import CONFIG

    run_id = CONFIG.service.run_id
    return run_id
@pytest.fixture
def mlruns_dir():
    """Return the package's ``MLRUNS_DIR`` location."""
    from image_prediction.locations import MLRUNS_DIR

    directory = MLRUNS_DIR
    return directory
@pytest.fixture
def mlflow_reader(mlruns_dir):
    """Provide an ``MlflowModelReader`` rooted at ``mlruns_dir``."""
    reader = MlflowModelReader(mlruns_dir)
    return reader