Matthias Bisping 37ee086b5d applied black
2022-04-05 17:55:38 +02:00

488 lines
14 KiB
Python

import json
import logging
import os
import random
import string
import tempfile
from functools import partial
from itertools import starmap
from operator import itemgetter
import fpdf
import numpy as np
import pytest
from PIL import Image
from funcy import rcompose
from image_prediction.classifier.classifier import Classifier
from image_prediction.classifier.image_classifier import ImageClassifier
from image_prediction.estimator.adapter.adapter import EstimatorAdapter
from image_prediction.estimator.preprocessor.preprocessors.basic import BasicPreprocessor
from image_prediction.exceptions import (
UnknownEstimatorAdapter,
UnknownImageExtractor,
UnknownDatabaseType,
UnknownLabelFormat,
)
from image_prediction.image_extractor.extractor import ImageMetadataPair
from image_prediction.image_extractor.extractors.mock import ImageExtractorMock
from image_prediction.image_extractor.extractors.parsable import ParsablePDFImageExtractor
from image_prediction.info import Info
from image_prediction.label_mapper.mappers.numeric import IndexMapper
from image_prediction.label_mapper.mappers.probability import ProbabilityMapper, ProbabilityMapperKeys
from image_prediction.locations import TEST_DATA_DIR
from image_prediction.model_loader.database.connectors.mock import DatabaseConnectorMock
from image_prediction.model_loader.loader import ModelLoader
from image_prediction.model_loader.loaders.mlflow import MlflowConnector
from image_prediction.pipeline import load_pipeline
from image_prediction.redai_adapter.mlflow import MlflowModelReader
from image_prediction.redai_adapter.model import PredictionModelHandle
from image_prediction.utils import get_logger
@pytest.fixture(autouse=True)
def mute_logger():
    """Silence the project logger for every test and restore its previous level afterwards."""
    project_logger = get_logger()
    previous_level = project_logger.level
    # One step above CRITICAL, so no record of any standard level is emitted.
    project_logger.setLevel(logging.CRITICAL + 1)
    yield
    project_logger.setLevel(previous_level)
@pytest.fixture
def image_extractor(extractor_type):
    """Return the image extractor selected by the ``extractor_type`` fixture.

    ``"mock"`` and ``"parsable_pdf"`` produce fresh extractor instances,
    ``"default"`` produces ``None``, and any other value raises
    ``UnknownImageExtractor``.
    """
    if extractor_type == "mock":
        return ImageExtractorMock()
    if extractor_type == "parsable_pdf":
        return ParsablePDFImageExtractor()
    if extractor_type == "default":
        return None
    raise UnknownImageExtractor(f"No image extractor for type {extractor_type} was specified.")
@pytest.fixture
def image_classifier(classifier, monkeypatch, batch_of_expected_string_labels):
    """Wrap ``classifier`` in an ``ImageClassifier`` that uses ``BasicPreprocessor``.

    ``monkeypatch`` and ``batch_of_expected_string_labels`` are not used in the
    body. Requesting them still makes pytest set them up before this fixture runs.
    """
    basic_preprocessor = BasicPreprocessor()
    return ImageClassifier(classifier, preprocessor=basic_preprocessor)
@pytest.fixture
def classifier(estimator_adapter, label_mapper):
    """Combine the estimator adapter and label mapper into a ``Classifier``."""
    return Classifier(estimator_adapter, label_mapper)
@pytest.fixture
def estimator_mock():
    """Return a stand-in estimator that predicts ``None`` for every batch item.

    ``predict``, ``predict_proba`` and calling the instance each return a list
    holding one ``None`` per element obtained by iterating the batch.
    """

    class EstimatorMock:
        @staticmethod
        def predict(batch):
            return list(map(lambda _item: None, batch))

        @staticmethod
        def predict_proba(batch):
            return list(map(lambda _item: None, batch))

        def __call__(self, batch):
            return self.predict(batch)

    return EstimatorMock()
@pytest.fixture
def label_mapper(label_format, classes):
    """Build the label mapper for ``label_format`` over ``classes``.

    Raises:
        UnknownLabelFormat: if ``label_format`` is neither ``"index"`` nor ``"probability"``.
    """
    if label_format == "index":
        mapper_type = IndexMapper
    elif label_format == "probability":
        mapper_type = ProbabilityMapper
    else:
        raise UnknownLabelFormat(f"No label mapper for label format {label_format} was specified.")
    return mapper_type(classes)
@pytest.fixture(params=["index"])
def label_format(request):
    """Parametrize tests over label formats; only ``"index"`` is enabled here."""
    chosen_format = request.param
    return chosen_format
@pytest.fixture
def expected_predictions_mapped(
    label_format, batch_of_expected_string_labels, batch_of_expected_label_to_probability_mappings
):
    """Return the expected label-mapper output for the active ``label_format``.

    ``"index"`` gives the string labels; ``"probability"`` gives the
    label-to-probability mappings. Any other format raises ``UnknownLabelFormat``.
    """
    if label_format == "index":
        return batch_of_expected_string_labels
    if label_format == "probability":
        return batch_of_expected_label_to_probability_mappings
    raise UnknownLabelFormat(f"No label mapper for label format {label_format} was specified.")
@pytest.fixture
def expected_predictions(label_format, batch_of_expected_numeric_labels, batch_of_expected_probability_arrays):
    """Return the raw (unmapped) predictions the estimator is expected to emit.

    ``"index"`` gives numeric class indices; ``"probability"`` gives per-class
    probability arrays. Any other format raises ``UnknownLabelFormat``.
    """
    if label_format == "index":
        return batch_of_expected_numeric_labels
    if label_format == "probability":
        return batch_of_expected_probability_arrays
    raise UnknownLabelFormat(f"No label mapper for label format {label_format} was specified.")
@pytest.fixture
def estimator_adapter(
    estimator_type, estimator_mock, keras_model, model_handle_mock, output_batch_generator, monkeypatch
):
    """Build an ``EstimatorAdapter`` for ``estimator_type`` with a stubbed ``predict``.

    The adapter's real ``predict`` still runs, so mechanical problems surface.
    Its results are then replaced, item for item, by values drawn from
    ``output_batch_generator``. This lets callers be checked against known
    expected predictions.

    Raises:
        UnknownEstimatorAdapter: for an unsupported ``estimator_type``.
    """
    if estimator_type == "mock":
        wrapped_estimator = estimator_mock
    elif estimator_type == "keras":
        wrapped_estimator = keras_model
    elif estimator_type == "redai":
        wrapped_estimator = PredictionModelHandle(model_handle_mock)
    else:
        raise UnknownEstimatorAdapter(f"No adapter for estimator type {estimator_type} was specified.")
    adapter = EstimatorAdapter(wrapped_estimator)

    # Keep a handle on the genuine bound method before it is patched over.
    real_predict = adapter.predict

    def predict_with_expected_outputs(batch):
        # One expected prediction per item that the real predict produced.
        return [next(output_batch_generator) for _ in real_predict(batch)]

    monkeypatch.setattr(adapter, "predict", predict_with_expected_outputs)
    return adapter
@pytest.fixture
def keras_model(input_size):
    """Build and compile a tiny Keras model whose input shape is ``input_size``.

    TensorFlow is imported lazily so only tests that need it pay the import cost.
    ``TF_CPP_MIN_LOG_LEVEL`` is set before that import to quiet TensorFlow's native logging.
    """
    os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
    import tensorflow as tf

    keras = tf.keras
    keras.backend.set_image_data_format("channels_last")
    model_input = keras.Input(shape=input_size)
    features = keras.layers.Conv2D(3, 3)(model_input)
    features = keras.layers.Dense(10)(features)
    model_output = keras.layers.Dense(10)(features)
    model = keras.Model(inputs=model_input, outputs=model_output)
    model.compile()
    return model
@pytest.fixture
def images(input_batch):
    """Convert every array in ``input_batch`` into a PIL image via ``array_to_image``."""
    return [array_to_image(sample) for sample in input_batch]
@pytest.fixture
def input_batch(batch_size, input_size):
    """Return ``batch_size`` arrays of shape ``input_size`` filled with uniform samples in [0, 1)."""
    batch_shape = (batch_size, *input_size)
    return np.random.random_sample(size=batch_shape)
@pytest.fixture(params=[0, 1, 2, 16, 32])
def batch_size(request):
    """Parametrize tests over batch sizes, including the empty batch."""
    size = request.param
    return size
@pytest.fixture(params=[{"width": 10, "height": 15, "depth": 3}, {"width": 150, "height": 100, "depth": 3}])
def input_size(request):
    """Parametrize tests over input sizes, returned as a ``(width, height, depth)`` tuple.

    NOTE(review): the tuple is used directly as the trailing array shape of the
    input batch. ``Image.fromarray`` treats the first axis as rows, so the
    "width" value ends up as the PIL image's height. Confirm this is intended.
    """
    return itemgetter("width", "height", "depth")(request.param)


def array_to_image(array):
    """Convert an array of floats in ``[0, 1]`` into an 8-bit RGB PIL image.

    The values are scaled by 255 and truncated to ``uint8``. The array needs a
    trailing axis of 3 channels to be valid for ``mode="RGB"``.

    Raises:
        AssertionError: if any value lies outside ``[0, 1]``.
    """
    assert np.all(array <= 1)
    assert np.all(array >= 0)
    return Image.fromarray(np.uint8(array * 255), mode="RGB")
@pytest.fixture
def batch_of_expected_string_labels(batch_of_expected_numeric_labels, classes):
    """Translate the expected numeric labels into their class names."""
    string_labels = map_labels(batch_of_expected_numeric_labels, classes)
    return string_labels
@pytest.fixture
def batch_of_expected_numeric_labels(batch_size, classes):
    """Draw ``batch_size`` random class indices (with replacement) from ``classes``."""
    class_indices = range(len(classes))
    return random.choices(class_indices, k=batch_size)
@pytest.fixture
def batch_of_expected_label_to_probability_mappings(batch_of_expected_probability_arrays, classes):
    """Turn each probability array into the mapping the probability label mapper should produce.

    Each mapping has two entries:
    - ``ProbabilityMapperKeys.LABEL``: the class with the highest rounded probability.
    - ``ProbabilityMapperKeys.PROBABILITIES``: a class-to-probability dict with
      values rounded to 4 decimals (as plain floats), ordered from most to least likely.
    """

    def round_to_float(value):
        return float(np.round(value, decimals=4))

    def to_expected_mapping(probabilities):
        ranked_pairs = sorted(
            zip(classes, map(round_to_float, probabilities)),
            key=itemgetter(1),
            reverse=True,
        )
        label_to_probability = dict(ranked_pairs)
        top_label = list(label_to_probability)[0]
        return {
            ProbabilityMapperKeys.LABEL: top_label,
            ProbabilityMapperKeys.PROBABILITIES: label_to_probability,
        }

    return [to_expected_mapping(probabilities) for probabilities in batch_of_expected_probability_arrays]
@pytest.fixture
def batch_of_expected_probability_arrays(batch_size, classes):
    """Return ``batch_size`` arrays of uniform [0, 1) scores, one score per class."""
    class_count = len(classes)
    probability_arrays = []
    for _ in range(batch_size):
        probability_arrays.append(np.random.uniform(size=class_count))
    return probability_arrays
@pytest.fixture
def output_batch_generator(expected_predictions):
    """Return an iterator that hands out the expected predictions one at a time."""
    prediction_iterator = iter(expected_predictions)
    return prediction_iterator
@pytest.fixture
def classes():
    """Return the class names used throughout the tests: ``["A", "B", "C"]``."""
    return list("ABC")
def map_labels(numeric_labels, classes):
    """Look up each numeric label in ``classes`` and return the results as a list."""
    labels = []
    for index in numeric_labels:
        labels.append(classes[index])
    return labels
@pytest.fixture
def metadata_plus_mapped_prediction(expected_predictions_mapped, metadata):
    """Merge each mapped prediction, under ``"classification"``, into its image's metadata dict."""

    def merge(prediction, image_metadata):
        return {"classification": prediction, **image_metadata}

    return list(starmap(merge, zip(expected_predictions_mapped, metadata)))
@pytest.fixture
def metadata_formatted_plus_mapped_prediction_formatted(expected_predictions_mapped_and_formatted, metadata_formatted):
    """Like ``metadata_plus_mapped_prediction``, but built from the string-keyed (formatted) variants."""

    def merge(prediction, image_metadata):
        return {"classification": prediction, **image_metadata}

    pairs = zip(expected_predictions_mapped_and_formatted, metadata_formatted)
    return list(starmap(merge, pairs))
@pytest.fixture
def expected_predictions_mapped_and_formatted(expected_predictions_mapped):
    """Replace the enum keys of each mapped prediction with the enums' ``.value``.

    NOTE(review): this calls ``.items()`` on every element, so it only works
    when the mapped predictions are dicts (the ``"probability"`` label format).
    """

    def with_plain_keys(prediction):
        return {key.value: value for key, value in prediction.items()}

    return list(map(with_plain_keys, expected_predictions_mapped))
@pytest.fixture
def metadata(images, info_label_map):
    """Build one position-metadata dict per image in ``images``, in order.

    Each image is placed at a random position on a fixed 595 x 842 page so that
    it lies entirely on the page. Page indices are assigned from a running
    counter that advances by a random 0-3 per image and is clamped to
    ``len(images) - 1``, so the indices never decrease.

    The keys are the ``info_label_map`` members PAGE_WIDTH, PAGE_HEIGHT,
    PAGE_IDX, WIDTH, HEIGHT, X1, X2, Y1 and Y2.

    NOTE: the generated values depend on the order of the ``random`` calls
    (x1, then y1, then page index, per image). Preserve it when editing.
    NOTE(review): ``random.randint`` raises ``ValueError`` if an image is wider
    or taller than the page.
    """
    page_idx = 0

    def current_page_idx():
        # Advance the shared counter by 0-3 pages and clamp the result to the
        # last index that can exist for len(images) images.
        nonlocal page_idx
        page_idx += random.randint(0, 3)
        return min(page_idx, len(images) - 1)

    def build_image_metadata(image):
        # PIL reports image size as (width, height).
        width, height = image.size
        # Fixed page dimensions (presumably A4 in points -- confirm).
        page_width = 595
        page_height = 842
        # Random top-left corner, chosen so the whole image fits on the page.
        x1 = random.randint(0, page_width - width)
        x2 = x1 + width
        y1 = random.randint(0, page_height - height)
        y2 = y1 + height
        metadata = {
            info_label_map.PAGE_WIDTH: page_width,
            info_label_map.PAGE_HEIGHT: page_height,
            # Evaluated after the x1/y1 draws above, which fixes the RNG call order.
            info_label_map.PAGE_IDX: current_page_idx(),
            info_label_map.WIDTH: width,
            info_label_map.HEIGHT: height,
            info_label_map.X1: x1,
            info_label_map.X2: x2,
            info_label_map.Y1: y1,
            info_label_map.Y2: y2,
        }
        return metadata

    return list(map(build_image_metadata, images))
@pytest.fixture
def info_label_map():
    """Expose the ``Info`` key enumeration to other fixtures."""
    label_map = Info
    return label_map
@pytest.fixture
def metadata_formatted(metadata):
    """Return copies of each metadata dict whose enum keys are replaced by their ``.value``."""
    return [{member.value: value for member, value in entry.items()} for entry in metadata]
@pytest.fixture
def image_metadata_pairs(images, metadata):
    """Pair each image with its metadata dict as an ``ImageMetadataPair``."""
    return [ImageMetadataPair(image, image_metadata) for image, image_metadata in zip(images, metadata)]
@pytest.fixture
def pdf(image_metadata_pairs):
    """Render every image onto its page of a new point-unit PDF and return the serialized document."""
    document = fpdf.FPDF(unit="pt")
    for image_metadata_pair in image_metadata_pairs:
        add_image(document, image_metadata_pair)
    return pdf_stream(document)
def add_image(pdf, image_metadata_pair):
    """Add pages to ``pdf`` until the image's page index exists, then draw the image on the last page.

    NOTE(review): if the document already has more pages than the index requires,
    the image still lands on the last page. This relies on page indices never decreasing.
    """
    target_page_idx = image_metadata_pair.metadata[Info.PAGE_IDX]
    while fewer_pages_then_required(target_page_idx, pdf):
        pdf.add_page()
    add_image_to_last_page(pdf, image_metadata_pair)
def fewer_pages_then_required(page_idx, pdf):
    """Return whether ``pdf`` lacks a page at zero-based index ``page_idx``.

    ``pdf.page`` is treated as the 1-based number of the current (last) page,
    presumably fpdf's page counter.
    """
    last_page_idx = pdf.page - 1
    return last_page_idx < page_idx
def pdf_stream(pdf: "fpdf.fpdf.FPDF"):
    """Serialize ``pdf`` and return the document as ``bytes``.

    Classic PyFPDF returns the document from ``output(dest="S")`` as a latin-1
    ``str``. fpdf2 returns a ``bytearray`` instead, which has no ``.encode``.
    Both forms are normalized to ``bytes`` here.

    The annotation is a string so it is not evaluated when the function is defined.
    """
    document = pdf.output(dest="S")
    if isinstance(document, str):
        return document.encode("latin1")
    return bytes(document)
def add_image_to_last_page(pdf: fpdf.fpdf.FPDF, image_metadata_pair):
    """Draw the pair's image on the current (last) page of ``pdf``.

    Position and size come from the metadata keys ``Info.X1``, ``Info.Y1``,
    ``Info.WIDTH`` and ``Info.HEIGHT``. The image is written to a temporary PNG
    file because ``pdf.image`` is given a path here.
    """
    image, metadata = image_metadata_pair
    x, y, w, h = itemgetter(Info.X1, Info.Y1, Info.WIDTH, Info.HEIGHT)(metadata)
    # A fresh directory rather than an open NamedTemporaryFile: on Windows a
    # NamedTemporaryFile cannot be reopened by name while it is open, so saving
    # into it would fail. The unique directory also keeps the path unique per
    # call, in case the PDF library caches images by file name.
    with tempfile.TemporaryDirectory() as temp_dir:
        image_path = os.path.join(temp_dir, "image.png")
        image.save(image_path)
        pdf.image(image_path, x=x, y=y, w=w, h=h, type="png")
@pytest.fixture
def model():
    """Return a dummy model whose ``predict``/``predict_proba`` accept any arguments and return ``True``."""

    class Model:
        @staticmethod
        def predict(*_args):
            return True

        @staticmethod
        def predict_proba(*_args):
            return True

    return Model()
@pytest.fixture
def model_database_record_identifier():
    """Return a random 10-letter identifier made of distinct ASCII letters."""
    letters = random.sample(string.ascii_letters, k=10)
    return "".join(letters)
@pytest.fixture
def model_database_record(model, classes):
    """Return a database record holding the dummy model and its class names."""
    return dict(model=model, classes=classes)
@pytest.fixture
def model_database(model_database_record, model_database_record_identifier):
    """Return an in-memory database with a single record stored under the random identifier."""
    database = {}
    database[model_database_record_identifier] = model_database_record
    return database
@pytest.fixture
def database_connector(database_type, model_database, mlflow_reader):
    """Return the database connector selected by ``database_type``.

    Both ``model_database`` and ``mlflow_reader`` are requested in every case,
    so pytest sets up both regardless of the type chosen.

    Raises:
        UnknownDatabaseType: for an unsupported ``database_type``.
    """
    if database_type == "mock":
        return DatabaseConnectorMock(model_database)
    if database_type == "mlflow":
        return MlflowConnector(mlflow_reader)
    raise UnknownDatabaseType(f"No connector for database type {database_type} was specified.")
@pytest.fixture
def model_loader(database_connector):
    """Return a ``ModelLoader`` backed by the selected database connector."""
    loader = ModelLoader(database_connector)
    return loader
@pytest.fixture
def mlflow_run_id():
    """Return the MLflow run id configured under ``CONFIG.service.run_id``."""
    from image_prediction.config import CONFIG

    service_config = CONFIG.service
    return service_config.run_id
@pytest.fixture
def mlruns_dir():
    """Return the project's ``MLRUNS_DIR`` location (imported lazily)."""
    from image_prediction import locations

    return locations.MLRUNS_DIR
@pytest.fixture
def mlflow_reader(mlruns_dir):
    """Return an ``MlflowModelReader`` rooted at ``mlruns_dir``."""
    reader = MlflowModelReader(mlruns_dir)
    return reader
@pytest.fixture
def model_handle_mock(estimator_mock):
    """Return a model-handle stand-in that wraps ``estimator_mock``.

    ``prep_images``, ``predict`` and ``predict_proba`` each return one ``None``
    per element of the given batch.
    """

    class ModelHandleMock:
        def __init__(self):
            self.model = estimator_mock

        def prep_images(self, batch):
            return list(map(lambda _item: None, batch))

        def predict(self, batch):
            return list(map(lambda _item: None, batch))

        def predict_proba(self, batch):
            return list(map(lambda _item: None, batch))

    return ModelHandleMock()
@pytest.fixture
def real_pdf():
    """Yield the raw bytes of the sample PDF stored in ``TEST_DATA_DIR``."""
    pdf_path = os.path.join(TEST_DATA_DIR, "f2dc689ca794fccb8cd38b95f2bf6ba9.pdf")
    with open(pdf_path, "rb") as pdf_file:
        content = pdf_file.read()
        yield content
@pytest.fixture
def real_expected_service_response():
    """Return the expected service response for the sample PDF, parsed from its JSON file.

    The file is opened as UTF-8 explicitly. JSON text is UTF-8 by specification,
    whereas ``open`` without ``encoding`` uses the locale encoding (e.g. cp1252
    on Windows) and can mis-decode non-ASCII content. The file is closed before
    the test runs, because the parsed value is returned rather than yielded.
    """
    response_path = os.path.join(TEST_DATA_DIR, "f2dc689ca794fccb8cd38b95f2bf6ba9_predictions.json")
    with open(response_path, "r", encoding="utf-8") as f:
        return json.load(f)
@pytest.fixture
def pipeline():
    """Return the image-prediction pipeline loaded with ``verbose=False``."""
    return load_pipeline(verbose=False)
def transform_equal(a, b):
    """Return ``a == b``, first materializing ``a`` into a list when it is a ``map`` object."""
    if isinstance(a, map):
        a = list(a)
    return a == b
def get_base_position_metadata(width, height, page_width, page_height):
    """Return the size/page metadata shared by position tests, with the page index fixed at 0."""
    keys = (Info.WIDTH, Info.HEIGHT, Info.PAGE_IDX, Info.PAGE_WIDTH, Info.PAGE_HEIGHT)
    values = (width, height, 0, page_width, page_height)
    return dict(zip(keys, values))
@pytest.fixture(params=[33, 100])
def height(request):
    """Parametrize tests over image heights."""
    value = request.param
    return value
@pytest.fixture(params=[10, 31])
def width(request):
    """Parametrize tests over image widths."""
    value = request.param
    return value
@pytest.fixture(params=[220, 30])
def page_height(request):
    """Parametrize tests over page heights."""
    value = request.param
    return value
@pytest.fixture(params=[100, 310])
def page_width(request):
    """Parametrize tests over page widths."""
    value = request.param
    return value
def random_single_color_image_from_metadata(metadata):
    """Create an RGB image of the metadata's width/height filled with one random color.

    Each channel is ``int(uniform[0, 1) * 255)``, so it lies in 0..254.
    """
    image_size = (metadata[Info.WIDTH], metadata[Info.HEIGHT])
    fill_color = tuple(int(channel) for channel in np.random.uniform(size=3) * 255)
    return Image.new("RGB", image_size, color=fill_color)