516 lines
14 KiB
Python
516 lines
14 KiB
Python
import json
|
|
import logging
|
|
import os
|
|
import random
|
|
import string
|
|
import tempfile
|
|
from functools import partial
|
|
from itertools import starmap
|
|
from operator import itemgetter
|
|
|
|
import fpdf
|
|
import numpy as np
|
|
import pytest
|
|
from PIL import Image
|
|
from funcy import rcompose, merge
|
|
|
|
from image_prediction.classifier.classifier import Classifier
|
|
from image_prediction.classifier.image_classifier import ImageClassifier
|
|
from image_prediction.estimator.adapter.adapter import EstimatorAdapter
|
|
from image_prediction.estimator.preprocessor.preprocessors.basic import BasicPreprocessor
|
|
from image_prediction.exceptions import (
|
|
UnknownEstimatorAdapter,
|
|
UnknownImageExtractor,
|
|
UnknownDatabaseType,
|
|
UnknownLabelFormat,
|
|
)
|
|
from image_prediction.image_extractor.extractor import ImageMetadataPair
|
|
from image_prediction.image_extractor.extractors.mock import ImageExtractorMock
|
|
from image_prediction.image_extractor.extractors.parsable import ParsablePDFImageExtractor
|
|
from image_prediction.info import Info
|
|
from image_prediction.label_mapper.mappers.numeric import IndexMapper
|
|
from image_prediction.label_mapper.mappers.probability import ProbabilityMapper, ProbabilityMapperKeys
|
|
from image_prediction.locations import TEST_DATA_DIR
|
|
from image_prediction.model_loader.database.connectors.mock import DatabaseConnectorMock
|
|
from image_prediction.model_loader.loader import ModelLoader
|
|
from image_prediction.model_loader.loaders.mlflow import MlflowConnector
|
|
from image_prediction.pipeline import load_pipeline
|
|
from image_prediction.redai_adapter.mlflow import MlflowModelReader
|
|
from image_prediction.redai_adapter.model import PredictionModelHandle
|
|
from image_prediction.utils import get_logger
|
|
|
|
|
|
@pytest.fixture(autouse=True)
def mute_logger():
    """Silence the logger returned by ``get_logger()`` while a test runs.

    The level is raised above ``logging.CRITICAL`` before the test and reset
    to its previous value afterwards.
    """
    target_logger = get_logger()
    previous_level = target_logger.level
    silent_level = logging.CRITICAL + 1

    target_logger.setLevel(silent_level)
    yield
    target_logger.setLevel(previous_level)
|
|
|
|
|
|
@pytest.fixture
def image_extractor(extractor_type):
    """Build the image extractor selected by the ``extractor_type`` fixture.

    Returns ``None`` for ``"default"``.

    Raises:
        UnknownImageExtractor: if ``extractor_type`` is not one of the handled values.
    """
    if extractor_type == "default":
        return None
    if extractor_type == "mock":
        return ImageExtractorMock()
    if extractor_type == "parsable_pdf":
        return ParsablePDFImageExtractor()

    raise UnknownImageExtractor(f"No image extractor for type {extractor_type} was specified.")
|
|
|
|
|
|
@pytest.fixture
def image_classifier(classifier, monkeypatch, batch_of_expected_string_labels):
    """Wrap the ``classifier`` fixture in an ``ImageClassifier`` that uses a ``BasicPreprocessor``."""
    # NOTE: `monkeypatch` and `batch_of_expected_string_labels` are requested but not used in the body.
    preprocessor = BasicPreprocessor()
    return ImageClassifier(classifier, preprocessor=preprocessor)
|
|
|
|
|
|
@pytest.fixture
def classifier(estimator_adapter, label_mapper):
    """Combine the estimator adapter and label mapper fixtures into a ``Classifier``."""
    return Classifier(estimator_adapter, label_mapper)
|
|
|
|
|
|
@pytest.fixture
def estimator_mock():
    """Provide a stand-in estimator that answers every batch item with ``None``."""

    def placeholder_predictions(batch):
        # One ``None`` per element of the batch; the content of the batch is ignored.
        return [None for _ in batch]

    class EstimatorMock:
        """Estimator stub exposing ``predict``, ``predict_proba`` and the call protocol."""

        @staticmethod
        def predict(batch):
            return placeholder_predictions(batch)

        @staticmethod
        def predict_proba(batch):
            return placeholder_predictions(batch)

        def __call__(self, batch):
            # Calling the instance is the same as calling ``predict``.
            return self.predict(batch)

    return EstimatorMock()
|
|
|
|
|
|
@pytest.fixture
def label_mapper(label_format, classes):
    """Build the label mapper matching ``label_format`` for the given ``classes``.

    Raises:
        UnknownLabelFormat: if ``label_format`` is neither ``"index"`` nor ``"probability"``.
    """
    if label_format == "index":
        return IndexMapper(classes)
    if label_format == "probability":
        return ProbabilityMapper(classes)

    raise UnknownLabelFormat(f"No label mapper for label format {label_format} was specified.")
|
|
|
|
|
|
@pytest.fixture(params=["index"])
def label_format(request):
    """Parametrized label format; the only parameter declared here is ``"index"``."""
    selected_format = request.param
    return selected_format
|
|
|
|
|
|
@pytest.fixture
def expected_predictions_mapped(
    label_format, batch_of_expected_string_labels, batch_of_expected_label_to_probability_mappings
):
    """Select the expected mapped predictions that correspond to ``label_format``.

    Raises:
        UnknownLabelFormat: if ``label_format`` is neither ``"index"`` nor ``"probability"``.
    """
    if label_format == "index":
        return batch_of_expected_string_labels
    if label_format == "probability":
        return batch_of_expected_label_to_probability_mappings

    raise UnknownLabelFormat(f"No label mapper for label format {label_format} was specified.")
|
|
|
|
|
|
@pytest.fixture
def expected_predictions(label_format, batch_of_expected_numeric_labels, batch_of_expected_probability_arrays):
    """Select the expected raw (unmapped) predictions that correspond to ``label_format``.

    Raises:
        UnknownLabelFormat: if ``label_format`` is neither ``"index"`` nor ``"probability"``.
    """
    if label_format == "index":
        return batch_of_expected_numeric_labels
    if label_format == "probability":
        return batch_of_expected_probability_arrays

    raise UnknownLabelFormat(f"No label mapper for label format {label_format} was specified.")
|
|
|
|
|
|
@pytest.fixture
def estimator_adapter(
    estimator_type, estimator_mock, keras_model, model_handle_mock, output_batch_generator, monkeypatch
):
    """Build an ``EstimatorAdapter`` for ``estimator_type`` whose ``predict`` yields the expected predictions.

    Raises:
        UnknownEstimatorAdapter: if ``estimator_type`` is not ``"mock"``, ``"keras"`` or ``"redai"``.
    """
    if estimator_type == "mock":
        estimator = estimator_mock
    elif estimator_type == "keras":
        estimator = keras_model
    elif estimator_type == "redai":
        estimator = PredictionModelHandle(model_handle_mock)
    else:
        raise UnknownEstimatorAdapter(f"No adapter for estimator type {estimator_type} was specified.")

    adapter = EstimatorAdapter(estimator)
    # Keep a handle on the real implementation before it gets patched out below.
    real_predict = adapter.predict

    def mock_predict(batch):
        # Run real predict function to test for mechanical issues, but return externally defined
        # predictions to test the callers of the estimator adapter against the expected predictions
        return [next(output_batch_generator) for _ in real_predict(batch)]

    monkeypatch.setattr(adapter, "predict", mock_predict)

    return adapter
|
|
|
|
|
|
@pytest.fixture
def keras_model(input_size):
    """Build and compile a small Keras model (Conv2D followed by two Dense layers) for ``input_size``."""
    # Set before TensorFlow is imported below; presumably this quiets TensorFlow's native logging.
    os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"

    import tensorflow as tf

    tf.keras.backend.set_image_data_format("channels_last")

    layers = tf.keras.layers
    inputs = tf.keras.Input(shape=input_size)
    conv_layer = layers.Conv2D(3, 3)
    hidden_layer = layers.Dense(10)
    output_layer = layers.Dense(10)

    hidden = hidden_layer(conv_layer(inputs))
    outputs = output_layer(hidden)

    model = tf.keras.Model(inputs=inputs, outputs=outputs)
    model.compile()

    return model
|
|
|
|
|
|
@pytest.fixture
def images(input_batch):
    """Convert every array of ``input_batch`` into an image via ``array_to_image``."""
    return [array_to_image(array) for array in input_batch]
|
|
|
|
|
|
@pytest.fixture
def input_batch(batch_size, input_size):
    """Random float batch of shape ``(batch_size, *input_size)`` drawn with ``np.random.random_sample``."""
    batch_shape = (batch_size, *input_size)
    return np.random.random_sample(size=batch_shape)
|
|
|
|
|
|
@pytest.fixture(params=[0, 1, 2, 16, 32])
def batch_size(request):
    """Parametrized batch size; the declared parameters include the empty batch (0)."""
    size = request.param
    return size
|
|
|
|
|
|
@pytest.fixture(params=[False])
def input_size(request, __input_size):
    """Return the input size, with one extra channel when the ``alpha`` parameter is truthy.

    The fixture parameter acts as an alpha flag; the only parameter declared here is ``False``,
    in which case ``__input_size`` is passed through unchanged.
    """
    # The leftover debug print of the flag was removed: it produced output for every test.
    with_alpha = request.param
    if not with_alpha:
        return __input_size

    w, h, d = __input_size
    return w, h, d + 1
|
|
|
|
|
|
@pytest.fixture(params=[{"width": 10, "height": 15, "depth": 3}, {"width": 150, "height": 100, "depth": 3}])
def __input_size(request):
    """Turn the parametrized size mapping into a ``(width, height, depth)`` tuple."""
    size = request.param
    return size["width"], size["height"], size["depth"]
|
|
|
|
|
|
def array_to_image(array):
    """Convert an array with values in ``[0, 1]`` into an ``"RGB"`` or ``"RGBA"`` PIL image.

    The image mode is chosen from the size of the last axis (3 or 4) and the values are
    scaled by 255 and cast to ``uint8``.

    Raises:
        AssertionError: if any value lies outside ``[0, 1]``.
        ValueError: if the last axis has a size other than 3 or 4.
    """
    assert np.all(array <= 1)
    assert np.all(array >= 0)

    channel_count = array.shape[-1]
    if channel_count not in (3, 4):
        raise ValueError(f"Unexpected number of channels {channel_count}. Expected 3 or 4.")
    mode = "RGB" if channel_count == 3 else "RGBA"

    pixels = np.uint8(array * 255)
    # noinspection PyTypeChecker
    return Image.fromarray(pixels, mode=mode)
|
|
|
|
|
|
@pytest.fixture
def batch_of_expected_string_labels(batch_of_expected_numeric_labels, classes):
    """Expected labels as class names, obtained by mapping the expected numeric labels through ``classes``."""
    string_labels = map_labels(batch_of_expected_numeric_labels, classes)
    return string_labels
|
|
|
|
|
|
@pytest.fixture
def batch_of_expected_numeric_labels(batch_size, classes):
    """Draw ``batch_size`` random class indices (with replacement) from ``range(len(classes))``."""
    class_indices = range(len(classes))
    return random.choices(class_indices, k=batch_size)
|
|
|
|
|
|
@pytest.fixture
def batch_of_expected_label_to_probability_mappings(batch_of_expected_probability_arrays, classes):
    """Build, per probability array, the expected mapping of most likely label and label-to-probability dict.

    Probabilities are rounded to four decimals, converted to ``float`` and ordered from
    highest to lowest.
    """

    def round_probability(probability):
        return float(np.round(probability, decimals=4))

    def map_probabilities(probabilities):
        rounded = map(round_probability, probabilities)
        ranked = sorted(zip(classes, rounded), key=itemgetter(1), reverse=True)
        lbl2prob = dict(ranked)
        # The dict keeps the descending order, so its first key is the most likely label.
        most_likely = list(lbl2prob)[0]
        return {ProbabilityMapperKeys.LABEL: most_likely, ProbabilityMapperKeys.PROBABILITIES: lbl2prob}

    return [map_probabilities(probabilities) for probabilities in batch_of_expected_probability_arrays]
|
|
|
|
|
|
@pytest.fixture
def batch_of_expected_probability_arrays(batch_size, classes):
    """Create ``batch_size`` arrays of ``len(classes)`` values each, drawn with ``np.random.uniform``."""
    n_classes = len(classes)
    return [np.random.uniform(size=n_classes) for _ in range(batch_size)]
|
|
|
|
|
|
@pytest.fixture
def output_batch_generator(expected_predictions):
    """Iterator over ``expected_predictions``, consumed one item at a time by the patched ``predict``."""
    prediction_iterator = iter(expected_predictions)
    return prediction_iterator
|
|
|
|
|
|
@pytest.fixture
def classes():
    """Class names used throughout the tests."""
    return list("ABC")
|
|
|
|
|
|
def map_labels(numeric_labels, classes):
    """Translate each numeric label into the entry of ``classes`` at that index."""
    return list(map(classes.__getitem__, numeric_labels))
|
|
|
|
|
|
@pytest.fixture
def metadata_plus_mapped_prediction(expected_predictions_mapped, metadata):
    """Pair each expected mapped prediction with its metadata in one dict under ``"classification"``."""
    combined = []
    for prediction, image_metadata in zip(expected_predictions_mapped, metadata):
        combined.append({"classification": prediction, **image_metadata})
    return combined
|
|
|
|
|
|
@pytest.fixture
def metadata_formatted_plus_mapped_prediction_formatted(expected_predictions_mapped_and_formatted, metadata_formatted):
    """Pair each formatted prediction with its formatted metadata in one dict under ``"classification"``."""
    combined = []
    for prediction, image_metadata in zip(expected_predictions_mapped_and_formatted, metadata_formatted):
        combined.append({"classification": prediction, **image_metadata})
    return combined
|
|
|
|
|
|
@pytest.fixture
def expected_predictions_mapped_and_formatted(expected_predictions_mapped):
    """Replace the keys of every mapped prediction by their ``.value``."""

    def format_prediction(prediction):
        # Each prediction is treated as a mapping here (``.items()``), and each key must expose ``.value``.
        return {key.value: value for key, value in prediction.items()}

    return [format_prediction(prediction) for prediction in expected_predictions_mapped]
|
|
|
|
|
|
@pytest.fixture
def metadata(images, info_label_map):
    """Generate metadata with random placement for every image.

    Each image gets a random position on a 595 x 842 page, and a page index that never
    decreases from one image to the next and is capped at ``len(images) - 1``.
    """
    page_width = 595
    page_height = 842
    last_page_idx = len(images) - 1
    running_page_idx = 0

    def current_page_idx():
        # Advance by 0 to 3 pages per image, so several images can share a page.
        nonlocal running_page_idx
        running_page_idx += random.randint(0, 3)
        return min(running_page_idx, last_page_idx)

    def build_image_metadata(image):
        width, height = image.size
        # The order of the random draws (x, then y, then page index) is kept deliberately.
        x1 = random.randint(0, page_width - width)
        y1 = random.randint(0, page_height - height)
        page_idx = current_page_idx()
        return {
            info_label_map.PAGE_WIDTH: page_width,
            info_label_map.PAGE_HEIGHT: page_height,
            info_label_map.PAGE_IDX: page_idx,
            info_label_map.WIDTH: width,
            info_label_map.HEIGHT: height,
            info_label_map.X1: x1,
            info_label_map.X2: x1 + width,
            info_label_map.Y1: y1,
            info_label_map.Y2: y1 + height,
            info_label_map.ALPHA: image.mode == "RGBA",
        }

    return [build_image_metadata(image) for image in images]
|
|
|
|
|
|
@pytest.fixture
def info_label_map():
    """Expose the ``Info`` class, whose members are used as metadata keys."""
    label_map = Info
    return label_map
|
|
|
|
|
|
@pytest.fixture
def metadata_formatted(metadata):
    """Metadata with every key replaced by its ``.value``."""
    return [{key.value: val for key, val in image_metadata.items()} for image_metadata in metadata]
|
|
|
|
|
|
@pytest.fixture
def image_metadata_pairs(images, metadata):
    """Zip images and metadata into a list of ``ImageMetadataPair`` objects."""
    return [ImageMetadataPair(image, image_metadata) for image, image_metadata in zip(images, metadata)]
|
|
|
|
|
|
@pytest.fixture
def pdf(image_metadata_pairs):
    """Render all image/metadata pairs into a PDF (unit ``"pt"``) and return the result of ``pdf_stream``."""
    document = fpdf.FPDF(unit="pt")

    for image_metadata_pair in image_metadata_pairs:
        add_image(document, image_metadata_pair)

    return pdf_stream(document)
|
|
|
|
|
|
def add_image(pdf, image_metadata_pair):
    """Add pages until the page index of the pair exists, then place the image on the last page."""
    required_page_idx = image_metadata_pair.metadata[Info.PAGE_IDX]

    while fewer_pages_then_required(required_page_idx, pdf):
        pdf.add_page()

    add_image_to_last_page(pdf, image_metadata_pair)
|
|
|
|
|
|
def fewer_pages_then_required(page_idx, pdf):
    """Return whether ``page_idx`` lies beyond the last page index (``pdf.page - 1``) of ``pdf``."""
    last_page_idx = pdf.page - 1
    return last_page_idx < page_idx
|
|
|
|
|
|
def pdf_stream(pdf: fpdf.fpdf.FPDF):
    """Return the document produced by ``pdf.output(dest="S")``, encoded as latin1 bytes."""
    document_text = pdf.output(dest="S")
    return document_text.encode("latin1")
|
|
|
|
|
|
def add_image_to_last_page(pdf: fpdf.fpdf.FPDF, image_metadata_pair):
    """Place the image of ``image_metadata_pair`` on the PDF at the position given by its metadata.

    The image is written to a temporary PNG file, because ``pdf.image`` is given a file path.
    Position and size are taken from the ``Info.X1``, ``Info.Y1``, ``Info.WIDTH`` and
    ``Info.HEIGHT`` entries of the metadata.
    """
    image, metadata = image_metadata_pair
    x, y, w, h = itemgetter(Info.X1, Info.Y1, Info.WIDTH, Info.HEIGHT)(metadata)

    # Use a file inside a temporary directory rather than a still-open NamedTemporaryFile:
    # reopening an open named temporary file by name is not portable (it fails on Windows).
    with tempfile.TemporaryDirectory() as temp_dir:
        temp_image_path = os.path.join(temp_dir, "image.png")
        image.save(temp_image_path)
        pdf.image(temp_image_path, x=x, y=y, w=w, h=h, type="png")
|
|
|
|
|
|
@pytest.fixture
def model():
    """Provide a minimal model stub whose ``predict`` and ``predict_proba`` always return ``True``."""

    class Model:
        """Stub that accepts any positional arguments and ignores them."""

        @staticmethod
        def predict(*args):
            return True

        @staticmethod
        def predict_proba(*args):
            return True

    model_stub = Model()
    return model_stub
|
|
|
|
|
|
@pytest.fixture
def model_database_record_identifier():
    """Random identifier made of ten distinct ASCII letters."""
    letters = random.sample(string.ascii_letters, k=10)
    return "".join(letters)
|
|
|
|
|
|
@pytest.fixture
def model_database_record(model, classes):
    """Database record holding the model stub under ``"model"`` and the class names under ``"classes"``."""
    record = {"model": model, "classes": classes}
    return record
|
|
|
|
|
|
@pytest.fixture
def model_database(model_database_record, model_database_record_identifier):
    """Single-entry database mapping the record identifier to the record."""
    database = {}
    database[model_database_record_identifier] = model_database_record
    return database
|
|
|
|
|
|
@pytest.fixture
def database_connector(database_type, model_database, mlflow_reader):
    """Build the database connector selected by the ``database_type`` fixture.

    Raises:
        UnknownDatabaseType: if ``database_type`` is neither ``"mock"`` nor ``"mlflow"``.
    """
    if database_type == "mock":
        return DatabaseConnectorMock(model_database)
    if database_type == "mlflow":
        return MlflowConnector(mlflow_reader)

    raise UnknownDatabaseType(f"No connector for database type {database_type} was specified.")
|
|
|
|
|
|
@pytest.fixture
def model_loader(database_connector):
    """``ModelLoader`` backed by the ``database_connector`` fixture."""
    loader = ModelLoader(database_connector)
    return loader
|
|
|
|
|
|
@pytest.fixture
def mlflow_run_id():
    """Run id read from ``CONFIG.service.run_id``."""
    # Imported inside the fixture, as in the original, so the config is only imported on use.
    from image_prediction.config import CONFIG

    service_config = CONFIG.service
    return service_config.run_id
|
|
|
|
|
|
@pytest.fixture
def mlruns_dir():
    """Value of ``MLRUNS_DIR`` from ``image_prediction.locations``."""
    from image_prediction.locations import MLRUNS_DIR

    directory = MLRUNS_DIR
    return directory
|
|
|
|
|
|
@pytest.fixture
def mlflow_reader(mlruns_dir):
    """``MlflowModelReader`` constructed from the ``mlruns_dir`` fixture."""
    reader = MlflowModelReader(mlruns_dir)
    return reader
|
|
|
|
|
|
@pytest.fixture
def model_handle_mock(estimator_mock):
    """Provide a model handle stub that wraps ``estimator_mock`` and answers every batch item with ``None``."""

    def placeholder_per_item(batch):
        # One ``None`` per element of the batch; the content of the batch is ignored.
        return [None for _ in batch]

    class ModelHandleMock:
        """Stub exposing ``model``, ``prep_images``, ``predict`` and ``predict_proba``."""

        def __init__(self):
            self.model = estimator_mock

        def prep_images(self, batch):
            return placeholder_per_item(batch)

        def predict(self, batch):
            return placeholder_per_item(batch)

        def predict_proba(self, batch):
            return placeholder_per_item(batch)

    return ModelHandleMock()
|
|
|
|
|
|
@pytest.fixture
def real_pdf():
    """Yield the raw bytes of the sample PDF stored in ``TEST_DATA_DIR``."""
    pdf_path = os.path.join(TEST_DATA_DIR, "f2dc689ca794fccb8cd38b95f2bf6ba9.pdf")
    with open(pdf_path, "rb") as pdf_file:
        pdf_bytes = pdf_file.read()
        yield pdf_bytes
|
|
|
|
|
|
@pytest.fixture
def real_expected_service_response():
    """Yield the parsed JSON predictions file stored next to the sample PDF in ``TEST_DATA_DIR``."""
    predictions_path = os.path.join(TEST_DATA_DIR, "f2dc689ca794fccb8cd38b95f2bf6ba9_predictions.json")
    # Binary mode lets ``json.load`` detect the encoding (UTF-8/16/32) itself, instead of
    # decoding the file with the platform's locale-dependent default text encoding.
    with open(predictions_path, "rb") as f:
        yield json.load(f)
|
|
|
|
|
|
@pytest.fixture
def pipeline():
    """Pipeline obtained from ``load_pipeline(verbose=False)``."""
    return load_pipeline(verbose=False)
|
|
|
|
|
|
def transform_equal(a, b):
    """Compare ``a`` with ``b``, materializing ``a`` into a list first when it is a ``map`` object."""
    if isinstance(a, map):
        a = list(a)
    return a == b
|
|
|
|
|
|
def get_base_position_metadata(width, height, page_width, page_height):
    """Build metadata with image size and page size for an image on page index 0."""
    base_metadata = {}
    base_metadata[Info.WIDTH] = width
    base_metadata[Info.HEIGHT] = height
    base_metadata[Info.PAGE_IDX] = 0
    base_metadata[Info.PAGE_WIDTH] = page_width
    base_metadata[Info.PAGE_HEIGHT] = page_height
    return base_metadata
|
|
|
|
|
|
@pytest.fixture
def base_patch_metadata(width, height, page_width, page_height):
    """Base position metadata extended by a bounding box anchored at the origin (0, 0)."""
    position_metadata = get_base_position_metadata(width, height, page_width, page_height)
    bounding_box = {Info.X1: 0, Info.Y1: 0, Info.X2: width, Info.Y2: height}
    return merge(position_metadata, bounding_box)
|
|
|
|
|
|
@pytest.fixture(params=[33, 100])
def height(request):
    """Parametrized image height."""
    value = request.param
    return value
|
|
|
|
|
|
@pytest.fixture(params=[10, 31])
def width(request):
    """Parametrized image width."""
    value = request.param
    return value
|
|
|
|
|
|
@pytest.fixture(params=[220, 30])
def page_height(request):
    """Parametrized page height."""
    value = request.param
    return value
|
|
|
|
|
|
@pytest.fixture(params=[100, 310])
def page_width(request):
    """Parametrized page width."""
    value = request.param
    return value
|
|
|
|
|
|
def random_single_color_image_from_metadata(metadata):
    """Create an RGB image of the size given in ``metadata``, filled with one randomly drawn color."""
    size = (metadata[Info.WIDTH], metadata[Info.HEIGHT])
    # Three uniform samples from [0, 1), scaled by 255 and truncated to ints, form the color.
    color = tuple(map(int, np.random.uniform(size=3) * 255))
    return Image.new("RGB", size, color=color)
|
|
|
|
|
|
# TODO: rename: not random!
def random_size_gray_image_from_metadata(metadata):
    """Create an RGB image of the size given in ``metadata``, filled with the fixed gray (100, 100, 100)."""
    size = (metadata[Info.WIDTH], metadata[Info.HEIGHT])
    gray = (100, 100, 100)
    return Image.new("RGB", size, color=gray)
|