2022-04-14 18:49:18 +02:00

163 lines
4.7 KiB
Python

import json
import logging
import os
import random
from functools import partial
from itertools import starmap
from operator import itemgetter
import fpdf
import numpy as np
import pytest
from funcy import rcompose
from image_prediction.exceptions import (
UnknownLabelFormat,
)
from image_prediction.image_extractor.extractor import ImageMetadataPair
from image_prediction.info import Info
from image_prediction.label_mapper.mappers.probability import ProbabilityMapperKeys
from image_prediction.locations import TEST_DATA_DIR
from image_prediction.pipeline import load_pipeline
from image_prediction.utils import get_logger
from test.utils.generation.pdf import add_image, pdf_stream
from test.utils.label import map_labels
# Register the shared fixture modules under ``test.fixtures`` as pytest plugins so
# that their fixtures are available to every test collected under this conftest.
pytest_plugins = [
    f"test.fixtures.{module_name}"
    for module_name in ("image", "input", "label", "metadata", "model", "model_store", "parameters")
]
@pytest.fixture(autouse=True)
def mute_logger():
    """Silence the project logger around every test (autouse).

    The level of the logger returned by ``get_logger()`` is raised to one above
    ``logging.CRITICAL`` for the duration of the test and restored afterwards.
    Assumes ``get_logger()`` returns a standard ``logging.Logger`` — TODO confirm.
    """
    project_logger = get_logger()
    previous_level = project_logger.level
    # One above CRITICAL, so even CRITICAL records logged through this logger are dropped.
    project_logger.setLevel(logging.CRITICAL + 1)
    yield
    project_logger.setLevel(previous_level)
@pytest.fixture
def expected_predictions_mapped(
    label_format, batch_of_expected_string_labels, batch_of_expected_label_to_probability_mappings
):
    """Select the expected *mapped* predictions for the active ``label_format``.

    ``"index"`` selects the string labels; ``"probability"`` selects the
    label-to-probability mappings.

    Raises:
        UnknownLabelFormat: if ``label_format`` is neither of the above.
    """
    candidates = (
        ("index", batch_of_expected_string_labels),
        ("probability", batch_of_expected_label_to_probability_mappings),
    )
    for known_format, mapped_predictions in candidates:
        if label_format == known_format:
            return mapped_predictions
    raise UnknownLabelFormat(f"No label mapper for label format {label_format} was specified.")
@pytest.fixture
def expected_predictions(label_format, batch_of_expected_numeric_labels, batch_of_expected_probability_arrays):
    """Select the expected *unmapped* predictions for the active ``label_format``.

    ``"index"`` selects the numeric class indices; ``"probability"`` selects the
    per-class probability arrays.

    Raises:
        UnknownLabelFormat: if ``label_format`` is neither of the above.
    """
    candidates = (
        ("index", batch_of_expected_numeric_labels),
        ("probability", batch_of_expected_probability_arrays),
    )
    for known_format, predictions in candidates:
        if label_format == known_format:
            return predictions
    raise UnknownLabelFormat(f"No label mapper for label format {label_format} was specified.")
@pytest.fixture
def input_batch(batch_size, input_size):
    """Return a random float batch of shape ``(batch_size, *input_size)``.

    Values are sampled from [0, 1) using NumPy's global random state.
    """
    shape = (batch_size,) + tuple(input_size)
    return np.random.random_sample(size=shape)
@pytest.fixture(params=[0, 1, 2, 16, 32])
def batch_size(request):
    """Parametrize dependent tests over several batch sizes, including the empty batch (0)."""
    size = request.param
    return size
@pytest.fixture
def batch_of_expected_string_labels(batch_of_expected_numeric_labels, classes):
    """Translate the expected numeric labels with ``map_labels`` (presumably index -> class name)."""
    string_labels = map_labels(batch_of_expected_numeric_labels, classes)
    return string_labels
@pytest.fixture
def batch_of_expected_numeric_labels(batch_size, classes):
    """Draw ``batch_size`` class indices, with replacement, from ``range(len(classes))``.

    Uses the unseeded module-level ``random`` state.
    """
    class_indices = range(len(classes))
    return random.choices(class_indices, k=batch_size)
@pytest.fixture
def batch_of_expected_label_to_probability_mappings(batch_of_expected_probability_arrays, classes):
    """Build the expected probability-mapper output for each probability array.

    Each entry is a dict with:
      * ``ProbabilityMapperKeys.LABEL`` -> the class ranked first, and
      * ``ProbabilityMapperKeys.PROBABILITIES`` -> ``{class: probability}`` ordered by
        descending probability.

    Probabilities are rounded to 4 decimals and converted to ``float`` before ranking.
    """

    def round_probability(value):
        # Round first, then turn the NumPy scalar into a plain Python float.
        return float(np.round(value, decimals=4))

    def to_expected_mapping(probabilities):
        ranked = sorted(
            zip(classes, map(round_probability, probabilities)),
            key=itemgetter(1),
            reverse=True,
        )
        # sorted() stays stable with reverse=True, so ties keep the order of ``classes``.
        label_to_probability = dict(ranked)
        best_label = ranked[0][0]
        return {
            ProbabilityMapperKeys.LABEL: best_label,
            ProbabilityMapperKeys.PROBABILITIES: label_to_probability,
        }

    return [to_expected_mapping(probabilities) for probabilities in batch_of_expected_probability_arrays]
@pytest.fixture
def batch_of_expected_probability_arrays(batch_size, classes):
    """Return ``batch_size`` arrays with one uniform [0, 1) score per class.

    NOTE(review): scores are sampled independently and not normalised, so a row
    generally does not sum to 1.
    """
    class_count = len(classes)
    return [np.random.uniform(size=class_count) for _ in range(batch_size)]
@pytest.fixture
def output_batch_generator(expected_predictions):
    """Return a one-shot iterator over the expected (unmapped) predictions."""
    prediction_iterator = iter(expected_predictions)
    return prediction_iterator
@pytest.fixture
def metadata_plus_mapped_prediction(expected_predictions_mapped, metadata):
    """Combine each mapped prediction with its metadata record.

    Every resulting dict has ``"classification"`` as its first key, followed by the
    metadata keys. A ``"classification"`` key inside the metadata overrides the
    prediction. ``zip`` truncates to the shorter input.
    """

    def merge(prediction, record):
        entry = {"classification": prediction}
        entry.update(record)
        return entry

    return [merge(prediction, record) for prediction, record in zip(expected_predictions_mapped, metadata)]
@pytest.fixture
def metadata_formatted_plus_mapped_prediction_formatted(expected_predictions_mapped_and_formatted, metadata_formatted):
    """Combine each formatted mapped prediction with its formatted metadata record.

    Every resulting dict has ``"classification"`` as its first key, followed by the
    metadata keys. A ``"classification"`` key inside the metadata overrides the
    prediction. ``zip`` truncates to the shorter input.
    """

    def merge(prediction, record):
        entry = {"classification": prediction}
        entry.update(record)
        return entry

    pairs = zip(expected_predictions_mapped_and_formatted, metadata_formatted)
    return [merge(prediction, record) for prediction, record in pairs]
@pytest.fixture
def expected_predictions_mapped_and_formatted(expected_predictions_mapped):
    """Replace the top-level keys of each mapped prediction with their ``.value``.

    For the "probability" label format those keys are ``ProbabilityMapperKeys``
    members. The nested ``{class: probability}`` dict is left unchanged.

    NOTE(review): with the "index" label format the mapped predictions are the
    string labels, not dicts, so ``.items()`` would fail. Presumably this fixture is
    only requested by probability-format tests — TODO confirm.
    """

    def with_plain_keys(mapping):
        return {key.value: value for key, value in mapping.items()}

    return list(map(with_plain_keys, expected_predictions_mapped))
@pytest.fixture
def info_label_map():
    """Provide the ``Info`` class itself (not an instance) as the label map."""
    label_map = Info
    return label_map
@pytest.fixture
def image_metadata_pairs(images, metadata):
    """Pair images with their metadata as ``ImageMetadataPair`` objects (truncates to the shorter input)."""
    return [ImageMetadataPair(image, record) for image, record in zip(images, metadata)]
@pytest.fixture
def pdf(image_metadata_pairs):
    """Render every image/metadata pair into a synthetic PDF and return ``pdf_stream`` of it.

    The document is created with ``unit="pt"``, and each pair is placed via ``add_image``.
    The return type is whatever ``pdf_stream`` produces (presumably bytes — TODO confirm).
    """
    document = fpdf.FPDF(unit="pt")
    for image_metadata_pair in image_metadata_pairs:
        add_image(document, image_metadata_pair)
    return pdf_stream(document)
@pytest.fixture
def real_pdf():
    """Return the raw bytes of the sample PDF stored in ``TEST_DATA_DIR``.

    The file is read and closed immediately. It is not kept open for the whole test,
    which a ``yield`` inside the ``with`` block would do.
    """
    path = os.path.join(TEST_DATA_DIR, "f2dc689ca794fccb8cd38b95f2bf6ba9.pdf")
    with open(path, "rb") as f:
        return f.read()
@pytest.fixture
def real_expected_service_response():
    """Return the parsed expected predictions JSON for the sample PDF in ``TEST_DATA_DIR``.

    The file is read and closed immediately rather than held open for the whole test.
    """
    path = os.path.join(TEST_DATA_DIR, "f2dc689ca794fccb8cd38b95f2bf6ba9_predictions.json")
    # JSON text is UTF-8 (RFC 8259), so do not rely on the locale-dependent default encoding.
    with open(path, "r", encoding="utf-8") as f:
        return json.load(f)
@pytest.fixture
def pipeline():
    """Load the pipeline via ``load_pipeline`` with verbose output disabled."""
    return load_pipeline(verbose=False)