"""Pytest fixtures and helper functions for the image_prediction test suite."""

import random
import string
import tempfile
from itertools import starmap
from operator import itemgetter

import fpdf
import numpy as np
import pytest
from PIL import Image

from image_prediction.classifier.classifier import Classifier
from image_prediction.classifier.image_classifier import ImageClassifier
from image_prediction.estimator.adapter.adapter import EstimatorAdapter
from image_prediction.estimator.preprocessor.preprocessors.basic import BasicPreprocessor
from image_prediction.exceptions import UnknownEstimatorAdapter, UnknownImageExtractor, UnknownDatabaseType
from image_prediction.image_extractor.extractor import ImageMetadataPair
from image_prediction.image_extractor.extractors.mock import ImageExtractorMock
from image_prediction.image_extractor.extractors.parsable import ParsablePDFImageExtractor
from image_prediction.info import Info
from image_prediction.model_loader.database.connectors.mock import DatabaseConnectorMock
from image_prediction.model_loader.loader import ModelLoader
from image_prediction.model_loader.loaders.mlflow import MlflowConnector
from image_prediction.redai_adapter.mlflow import MlflowModelReader


@pytest.fixture
def image_extractor(extractor_type):
    """Return the image extractor selected by ``extractor_type``.

    ``"mock"`` and ``"parsable_pdf"`` yield extractor instances, ``"default"``
    yields ``None``.

    Raises:
        UnknownImageExtractor: If ``extractor_type`` is none of the above.
    """
    if extractor_type == "mock":
        return ImageExtractorMock()
    elif extractor_type == "parsable_pdf":
        return ParsablePDFImageExtractor()
    elif extractor_type == "default":
        return None
    else:
        raise UnknownImageExtractor(f"No image extractor for type {extractor_type} was specified.")


@pytest.fixture
def image_classifier(classifier, monkeypatch, expected_predictions):
    """Return an ``ImageClassifier`` wrapping ``classifier`` with a ``BasicPreprocessor``.

    ``monkeypatch`` and ``expected_predictions`` are not used in the body; they
    are kept in the signature so the fixture's dependencies stay unchanged.
    """
    return ImageClassifier(classifier, preprocessor=BasicPreprocessor())


@pytest.fixture
def classifier(estimator_adapter, classes):
    """Return a ``Classifier`` built from the estimator adapter and class labels."""
    return Classifier(estimator_adapter, classes)


class EstimatorMock:
    """Minimal estimator stand-in that predicts ``None`` for every batch item."""

    @staticmethod
    def predict(batch):
        """Return a list holding one ``None`` per element of ``batch``."""
        return [None for _ in batch]

    def __call__(self, batch):
        """Delegate to :meth:`predict` so the mock can be called directly."""
        return self.predict(batch)


@pytest.fixture
def estimator_adapter(estimator_type, keras_model, output_batch_generator, monkeypatch):
    """Return an ``EstimatorAdapter`` whose ``predict`` yields predefined outputs.

    The adapter wraps an ``EstimatorMock`` (``"mock"``) or the ``keras_model``
    fixture (``"keras"``). Its ``predict`` attribute is monkeypatched so that
    the real prediction still runs, but the returned values are drawn from
    ``output_batch_generator``.

    Raises:
        UnknownEstimatorAdapter: If ``estimator_type`` is neither ``"mock"`` nor ``"keras"``.
    """
    if estimator_type == "mock":
        adapter = EstimatorAdapter(EstimatorMock())
    elif estimator_type == "keras":
        adapter = EstimatorAdapter(keras_model)
    else:
        raise UnknownEstimatorAdapter(f"No adapter for estimator type {estimator_type} was specified.")

    # Capture the real method before it is replaced below.
    real_predict = adapter.predict

    def mock_predict(batch):
        # Run real predict function to test for mechanical issues, but return externally defined
        # predictions to test the callers of the estimator adapter against the expected predictions
        return [next(output_batch_generator) for _ in real_predict(batch)]

    monkeypatch.setattr(adapter, "predict", mock_predict)

    return adapter


@pytest.fixture
def keras_model(input_size):
    """Return a small compiled Keras model accepting inputs of shape ``input_size``.

    TensorFlow is imported lazily inside the fixture; ``TF_CPP_MIN_LOG_LEVEL``
    is set beforehand so the setting is in place when the import happens.
    """
    import os

    os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"

    import tensorflow as tf

    tf.keras.backend.set_image_data_format("channels_last")

    inputs = tf.keras.Input(shape=input_size)
    conv = tf.keras.layers.Conv2D(3, 3)
    dense = tf.keras.layers.Dense(10)

    outputs = tf.keras.layers.Dense(10)(dense(conv(inputs)))
    model = tf.keras.Model(inputs=inputs, outputs=outputs)
    model.compile()

    return model


@pytest.fixture
def images(input_batch):
    """Return the input batch converted to a list of PIL images."""
    return list(map(array_to_image, input_batch))


@pytest.fixture
def input_batch(batch_size, input_size):
    """Return a random float array of shape ``(batch_size, *input_size)`` with values in [0, 1)."""
    return np.random.random_sample(size=(batch_size, *input_size))


@pytest.fixture(params=[0, 1, 2, 16, 32])
def batch_size(request):
    """Parametrized number of images per batch (includes the empty batch)."""
    return request.param


# NOTE: this fixture used to be defined twice with identical bodies (once here and
# once after ``array_to_image``); the later definition merely rebound the same
# name, so a single definition is kept.
@pytest.fixture(params=[{"width": 10, "height": 15, "depth": 3}, {"width": 150, "height": 100, "depth": 3}])
def input_size(request):
    """Parametrized input size as a ``(width, height, depth)`` tuple."""
    return itemgetter("width", "height", "depth")(request.param)


def array_to_image(array):
    """Convert an array with values in [0, 1] to an RGB PIL image.

    The values are scaled by 255 and cast to ``uint8``. The range checks are
    plain assertions, which is acceptable for test helper code.
    """
    assert np.all(array <= 1)
    assert np.all(array >= 0)
    return Image.fromarray(np.uint8(array * 255), mode="RGB")


@pytest.fixture
def expected_predictions(output_batch, classes):
    """Return the class labels corresponding to the numeric ``output_batch``."""
    return map_labels(output_batch, classes)


@pytest.fixture
def output_batch(input_batch, classes):
    """Return one random class index per element of ``input_batch``."""
    return random.choices(range(len(classes)), k=len(input_batch))


@pytest.fixture
def output_batch_generator(output_batch):
    """Return an iterator over ``output_batch`` (consumed by the patched ``predict``)."""
    return iter(output_batch)


@pytest.fixture
def classes():
    """Return the class labels used throughout the tests."""
    return ["A", "B", "C"]


def map_labels(numeric_labels, classes):
    """Translate each numeric label into the class name at that index of ``classes``."""
    return [classes[nl] for nl in numeric_labels]


@pytest.fixture
def metadata(images, info_label_map):
    """Return one randomly generated metadata dict per image.

    Each dict is keyed by members of ``info_label_map`` and holds the page
    dimensions, a page index, the image dimensions and a bounding box chosen
    at random within the page.
    """
    page_idx = 0

    def current_page_idx():
        # Advance the page index by a random step of 0 to 3 for every image, so page
        # indices never decrease; the result is capped at ``len(images) - 1``.
        nonlocal page_idx
        page_idx += random.randint(0, 3)
        return min(page_idx, len(images) - 1)

    def build_image_metadata(image):
        width, height = image.size
        # Page dimensions; presumably A4 in points (the PDF is built with unit="pt").
        page_width = 595
        page_height = 842

        # Pick the top-left corner so that the bounding box stays within the page.
        x1 = random.randint(0, page_width - width)
        x2 = x1 + width
        y1 = random.randint(0, page_height - height)
        y2 = y1 + height

        return {
            info_label_map.PAGE_WIDTH: page_width,
            info_label_map.PAGE_HEIGHT: page_height,
            info_label_map.PAGE_IDX: current_page_idx(),
            info_label_map.WIDTH: width,
            info_label_map.HEIGHT: height,
            info_label_map.X1: x1,
            info_label_map.X2: x2,
            info_label_map.Y1: y1,
            info_label_map.Y2: y2,
        }

    return list(map(build_image_metadata, images))


@pytest.fixture
def info_label_map():
    """Return the ``Info`` object whose members serve as metadata keys."""
    return Info


@pytest.fixture
def metadata_formatted(metadata):
    """Return ``metadata`` with each key replaced by its ``.value``."""

    def format_metadata(metadata):
        return {key.value: val for key, val in metadata.items()}

    return list(map(format_metadata, metadata))


@pytest.fixture
def image_metadata_pairs(images, metadata):
    """Return ``ImageMetadataPair`` objects pairing each image with its metadata."""
    return list(starmap(ImageMetadataPair, zip(images, metadata)))


@pytest.fixture
def pdf(image_metadata_pairs):
    """Return the bytes of a PDF containing every image at its metadata position."""
    pdf = fpdf.FPDF(unit="pt")

    for pair in image_metadata_pairs:
        add_image(pdf, pair)

    return pdf_stream(pdf)


def add_image(pdf, image_metadata_pair):
    """Add pages to ``pdf`` until the image's page exists, then place the image on the last page."""
    while fewer_pages_then_required(image_metadata_pair.metadata[Info.PAGE_IDX], pdf):
        pdf.add_page()

    add_image_to_last_page(pdf, image_metadata_pair)


def fewer_pages_then_required(page_idx, pdf):
    """Return whether ``page_idx`` is greater than ``pdf.page - 1``.

    NOTE(review): ``pdf.page`` is taken to be the 1-based number of the current
    (last) page, making ``page_idx`` zero-based -- confirm against fpdf.
    """
    return page_idx > pdf.page - 1


def pdf_stream(pdf: fpdf.fpdf.FPDF):
    """Return the document produced by ``pdf.output(dest="S")`` encoded as latin-1 bytes."""
    return pdf.output(dest="S").encode("latin1")


def add_image_to_last_page(pdf: fpdf.fpdf.FPDF, image_metadata_pair):
    """Draw the pair's image onto ``pdf`` at the position and size given by its metadata.

    The image is written to a temporary PNG file because it is handed to
    ``pdf.image`` by file name; the file is removed when the ``with`` block exits.
    """
    image, metadata = image_metadata_pair
    x, y, w, h = itemgetter(Info.X1, Info.Y1, Info.WIDTH, Info.HEIGHT)(metadata)

    with tempfile.NamedTemporaryFile(suffix=".png") as temp_image:
        image.save(temp_image.name)
        pdf.image(temp_image.name, x=x, y=y, w=w, h=h, type="png")


@pytest.fixture
def model():
    """Return a dummy model whose ``predict`` and ``predict_proba`` always return ``True``."""

    class Model:
        @staticmethod
        def predict(*args):
            return True

        @staticmethod
        def predict_proba(*args):
            return True

    return Model()


@pytest.fixture
def model_database_record_identifier():
    """Return a random identifier made of 10 distinct ASCII letters."""
    return "".join(random.sample(string.ascii_letters, k=10))


@pytest.fixture
def model_database_record(model, classes):
    """Return a database record holding the dummy model and the class labels."""
    return {"model": model, "classes": classes}


@pytest.fixture
def model_database(model_database_record, model_database_record_identifier):
    """Return an in-memory model database mapping the identifier to its record."""
    return {model_database_record_identifier: model_database_record}


@pytest.fixture
def database_connector(database_type, model_database, mlflow_reader):
    """Return the database connector selected by ``database_type``.

    Raises:
        UnknownDatabaseType: If ``database_type`` is neither ``"mock"`` nor ``"mlflow"``.
    """
    if database_type == "mock":
        return DatabaseConnectorMock(model_database)
    elif database_type == "mlflow":
        return MlflowConnector(mlflow_reader)
    else:
        raise UnknownDatabaseType(f"No connector for database type {database_type} was specified.")


@pytest.fixture
def model_loader(database_connector):
    """Return a ``ModelLoader`` backed by ``database_connector``."""
    return ModelLoader(database_connector)


@pytest.fixture
def mlflow_run_id():
    """Return the run id configured under ``CONFIG.service.run_id``."""
    from image_prediction.config import CONFIG

    return CONFIG.service.run_id


@pytest.fixture
def mlruns_dir():
    """Return the project's ``MLRUNS_DIR`` location."""
    from image_prediction.locations import MLRUNS_DIR

    return MLRUNS_DIR


@pytest.fixture
def mlflow_reader(mlruns_dir):
    """Return an ``MlflowModelReader`` reading from ``mlruns_dir``."""
    return MlflowModelReader(mlruns_dir)