import json
import logging
import os
import random
import string
from functools import partial
from itertools import starmap
from operator import itemgetter

import fpdf
import numpy as np
import pytest
from funcy import rcompose, merge

from image_prediction.exceptions import (
    UnknownDatabaseType,
    UnknownLabelFormat,
)
from image_prediction.image_extractor.extractor import ImageMetadataPair
from image_prediction.info import Info
from image_prediction.label_mapper.mappers.numeric import IndexMapper
from image_prediction.label_mapper.mappers.probability import ProbabilityMapper, ProbabilityMapperKeys
from image_prediction.locations import TEST_DATA_DIR
from image_prediction.model_loader.database.connectors.mock import DatabaseConnectorMock
from image_prediction.model_loader.loader import ModelLoader
from image_prediction.model_loader.loaders.mlflow import MlflowConnector
from image_prediction.pipeline import load_pipeline
from image_prediction.redai_adapter.mlflow import MlflowModelReader
from image_prediction.utils import get_logger
from test.utils.generation.image import array_to_image
from test.utils.generation.pdf import add_image, pdf_stream

pytest_plugins = ['test.utils.model']


@pytest.fixture(autouse=True)
def mute_logger():
    """Raise the level of the logger from ``get_logger()`` above CRITICAL for each test.

    The level that was active before the test is restored after the test has run.
    """
    log = get_logger()
    previous_level = log.level
    log.setLevel(logging.CRITICAL + 1)
    yield
    log.setLevel(previous_level)


@pytest.fixture
def label_mapper(label_format, classes):
    """Build the label mapper that matches ``label_format``.

    Raises:
        UnknownLabelFormat: if ``label_format`` is neither "index" nor "probability".
    """
    if label_format == "index":
        return IndexMapper(classes)
    if label_format == "probability":
        return ProbabilityMapper(classes)
    raise UnknownLabelFormat(f"No label mapper for label format {label_format} was specified.")


@pytest.fixture(params=["index"])
def label_format(request):
    """Label format under test; parametrized with "index" only."""
    return request.param


@pytest.fixture
def expected_predictions_mapped(
    label_format, batch_of_expected_string_labels, batch_of_expected_label_to_probability_mappings
):
    """Select the expected mapped predictions that belong to ``label_format``.

    Raises:
        UnknownLabelFormat: if ``label_format`` is neither "index" nor "probability".
    """
    if label_format == "index":
        return batch_of_expected_string_labels
    if label_format == "probability":
        return batch_of_expected_label_to_probability_mappings
    raise UnknownLabelFormat(f"No label mapper for label format {label_format} was specified.")


@pytest.fixture
def expected_predictions(label_format, batch_of_expected_numeric_labels, batch_of_expected_probability_arrays):
    """Select the expected raw (unmapped) predictions that belong to ``label_format``.

    Raises:
        UnknownLabelFormat: if ``label_format`` is neither "index" nor "probability".
    """
    if label_format == "index":
        return batch_of_expected_numeric_labels
    if label_format == "probability":
        return batch_of_expected_probability_arrays
    raise UnknownLabelFormat(f"No label mapper for label format {label_format} was specified.")


@pytest.fixture
def images(input_batch):
    """Convert every array of ``input_batch`` into an image via ``array_to_image``."""
    return [array_to_image(array) for array in input_batch]


@pytest.fixture
def input_batch(batch_size, input_size):
    """Random array of shape ``(batch_size, *input_size)`` with samples from [0, 1)."""
    shape = (batch_size, *input_size)
    return np.random.random_sample(size=shape)


@pytest.fixture(params=[0, 1, 2, 16, 32])
def batch_size(request):
    """Number of items per batch; includes the empty batch."""
    return request.param


@pytest.fixture
def input_size(alpha, __input_size):
    """Return ``(width, height, depth)`` with ``alpha`` added to the depth."""
    width, height, depth = __input_size
    return width, height, depth + alpha


@pytest.fixture(params=[False])
def alpha(request):
    """Flag that ``input_size`` adds to the depth; parametrized with ``False`` only."""
    return request.param


@pytest.fixture(params=[{"width": 10, "height": 15, "depth": 3}, {"width": 150, "height": 100, "depth": 3}])
def __input_size(request):
    """Unpack the parametrized size dict into a ``(width, height, depth)`` tuple."""
    size = request.param
    return size["width"], size["height"], size["depth"]


@pytest.fixture
def batch_of_expected_string_labels(batch_of_expected_numeric_labels, classes):
    """Class names corresponding to the expected numeric labels."""
    return map_labels(batch_of_expected_numeric_labels, classes)


@pytest.fixture
def batch_of_expected_numeric_labels(batch_size, classes):
    """Draw ``batch_size`` random class indices (with replacement)."""
    class_indices = range(len(classes))
    return random.choices(class_indices, k=batch_size)


@pytest.fixture
def batch_of_expected_label_to_probability_mappings(batch_of_expected_probability_arrays, classes):
    """Build, per probability array, the expected label/probabilities mapping.

    Each probability is rounded to 4 decimals and converted to ``float``; the
    label-to-probability dict is ordered by descending probability and its first
    key is reported as the label.
    """
    # rcompose applies left to right: round first, then convert to a plain float.
    to_rounded_float = rcompose(partial(np.round, decimals=4), float)

    def to_mapping(probabilities):
        rounded = map(to_rounded_float, probabilities)
        ranked = sorted(zip(classes, rounded), key=itemgetter(1), reverse=True)
        label_to_probability = dict(ranked)
        top_label = list(label_to_probability)[0]
        return {
            ProbabilityMapperKeys.LABEL: top_label,
            ProbabilityMapperKeys.PROBABILITIES: label_to_probability,
        }

    return [to_mapping(probabilities) for probabilities in batch_of_expected_probability_arrays]


@pytest.fixture
def batch_of_expected_probability_arrays(batch_size, classes):
    """``batch_size`` arrays of uniform random values, one value per class."""
    n_classes = len(classes)
    return [np.random.uniform(size=n_classes) for _ in range(batch_size)]


@pytest.fixture
def output_batch_generator(expected_predictions):
    """Iterator over the expected predictions."""
    return iter(expected_predictions)


@pytest.fixture
def classes():
    """Class names used throughout the tests."""
    return ["A", "B", "C"]


def map_labels(numeric_labels, classes):
    """Look up the class name for every index in ``numeric_labels``."""
    return [classes[index] for index in numeric_labels]


@pytest.fixture
def metadata_plus_mapped_prediction(expected_predictions_mapped, metadata):
    """Pair each mapped prediction with its metadata in a single dict.

    The prediction is stored under "classification"; the metadata entries follow.
    """
    combined = []
    for prediction, image_metadata in zip(expected_predictions_mapped, metadata):
        combined.append({"classification": prediction, **image_metadata})
    return combined


@pytest.fixture
def metadata_formatted_plus_mapped_prediction_formatted(expected_predictions_mapped_and_formatted, metadata_formatted):
    """Pair each formatted prediction with its formatted metadata in a single dict.

    The prediction is stored under "classification"; the metadata entries follow.
    """
    combined = []
    for prediction, image_metadata in zip(expected_predictions_mapped_and_formatted, metadata_formatted):
        combined.append({"classification": prediction, **image_metadata})
    return combined


@pytest.fixture
def expected_predictions_mapped_and_formatted(expected_predictions_mapped):
    """Replace the keys of every mapped prediction with their ``.value``."""
    return [
        {key.value: value for key, value in prediction.items()}
        for prediction in expected_predictions_mapped
    ]


@pytest.fixture
def metadata(images, info_label_map):
    """Generate one random metadata dict per image.

    Each image gets a random position that keeps it inside a fixed-size page and
    a page index that never decreases from one image to the next.
    """
    # presumably A4 measured in points — confirm
    page_width = 595
    page_height = 842
    page_cursor = 0

    def next_page_idx():
        # Advance by 0-3 pages per image, clamped to the last image index.
        nonlocal page_cursor
        page_cursor += random.randint(0, 3)
        return min(page_cursor, len(images) - 1)

    def describe(image):
        width, height = image.size
        # Keep the draw order (x, y, page index) so seeded runs stay reproducible.
        x1 = random.randint(0, page_width - width)
        y1 = random.randint(0, page_height - height)
        return {
            info_label_map.PAGE_WIDTH: page_width,
            info_label_map.PAGE_HEIGHT: page_height,
            info_label_map.PAGE_IDX: next_page_idx(),
            info_label_map.WIDTH: width,
            info_label_map.HEIGHT: height,
            info_label_map.X1: x1,
            info_label_map.X2: x1 + width,
            info_label_map.Y1: y1,
            info_label_map.Y2: y1 + height,
            info_label_map.ALPHA: image.mode == "RGBA",
        }

    return [describe(image) for image in images]


@pytest.fixture
def info_label_map():
    """The ``Info`` object whose members serve as metadata keys."""
    return Info


@pytest.fixture
def metadata_formatted(metadata):
    """Replace the keys of every metadata dict with their ``.value``."""
    return [
        {key.value: value for key, value in image_metadata.items()}
        for image_metadata in metadata
    ]


@pytest.fixture
def image_metadata_pairs(images, metadata):
    """Zip images with their metadata into ``ImageMetadataPair`` objects."""
    paired = zip(images, metadata)
    return list(starmap(ImageMetadataPair, paired))


@pytest.fixture
def pdf(image_metadata_pairs):
    """Add every image/metadata pair to a PDF (unit "pt") and return ``pdf_stream`` of it."""
    document = fpdf.FPDF(unit="pt")
    for image_metadata_pair in image_metadata_pairs:
        add_image(document, image_metadata_pair)
    return pdf_stream(document)


@pytest.fixture
def model_database_record_identifier():
    """Random identifier made of 10 distinct ASCII letters."""
    letters = random.sample(string.ascii_letters, k=10)
    return "".join(letters)


@pytest.fixture
def model_database_record(model, classes):
    """Database record holding the model together with its classes."""
    record = {"model": model, "classes": classes}
    return record


@pytest.fixture
def model_database(model_database_record, model_database_record_identifier):
    """Single-entry database mapping the record identifier to the record."""
    database = {model_database_record_identifier: model_database_record}
    return database


@pytest.fixture
def database_connector(database_type, model_database, mlflow_reader):
    """Build the connector that matches ``database_type``.

    Raises:
        UnknownDatabaseType: if ``database_type`` is neither "mock" nor "mlflow".
    """
    if database_type == "mock":
        return DatabaseConnectorMock(model_database)
    if database_type == "mlflow":
        return MlflowConnector(mlflow_reader)
    raise UnknownDatabaseType(f"No connector for database type {database_type} was specified.")


@pytest.fixture
def model_loader(database_connector):
    """``ModelLoader`` wired to the database connector under test."""
    return ModelLoader(database_connector)


@pytest.fixture
def mlflow_run_id():
    """Run id read from ``CONFIG.service.run_id`` (config imported on use)."""
    from image_prediction.config import CONFIG

    return CONFIG.service.run_id


@pytest.fixture
def mlruns_dir():
    """``MLRUNS_DIR`` from the project locations (imported on use)."""
    from image_prediction.locations import MLRUNS_DIR

    return MLRUNS_DIR


@pytest.fixture
def mlflow_reader(mlruns_dir):
    """``MlflowModelReader`` constructed from ``mlruns_dir``."""
    return MlflowModelReader(mlruns_dir)


@pytest.fixture
def real_pdf():
    """Raw bytes of the sample PDF stored in ``TEST_DATA_DIR``."""
    pdf_path = os.path.join(TEST_DATA_DIR, "f2dc689ca794fccb8cd38b95f2bf6ba9.pdf")
    with open(pdf_path, "rb") as pdf_file:
        yield pdf_file.read()


@pytest.fixture
def real_expected_service_response():
    """Parsed JSON predictions stored next to the sample PDF in ``TEST_DATA_DIR``."""
    json_path = os.path.join(TEST_DATA_DIR, "f2dc689ca794fccb8cd38b95f2bf6ba9_predictions.json")
    with open(json_path, "r") as json_file:
        yield json.load(json_file)


@pytest.fixture
def pipeline():
    """Pipeline returned by ``load_pipeline(verbose=False)``."""
    return load_pipeline(verbose=False)


def get_base_position_metadata(width, height, page_width, page_height):
    """Return size and page metadata keyed by ``Info`` members, with page index 0."""
    base = {
        Info.WIDTH: width,
        Info.HEIGHT: height,
        Info.PAGE_IDX: 0,
        Info.PAGE_WIDTH: page_width,
        Info.PAGE_HEIGHT: page_height,
    }
    return base


@pytest.fixture
def base_patch_metadata(width, height, page_width, page_height):
    """Base position metadata plus a bounding box from (0, 0) to (width, height)."""
    position = get_base_position_metadata(width, height, page_width, page_height)
    bounding_box = {Info.X1: 0, Info.Y1: 0, Info.X2: width, Info.Y2: height}
    return merge(position, bounding_box)


@pytest.fixture(params=[33, 100])
def height(request):
    """Patch height under test."""
    return request.param


@pytest.fixture(params=[10, 31])
def width(request):
    """Patch width under test."""
    return request.param


@pytest.fixture(params=[220, 30])
def page_height(request):
    """Page height under test."""
    return request.param


@pytest.fixture(params=[100, 310])
def page_width(request):
    """Page width under test."""
    return request.param