2022-04-14 19:00:12 +02:00

91 lines
3.1 KiB
Python

import json
import os
import random
from functools import partial
from operator import itemgetter
import numpy as np
import pytest
from funcy import rcompose
from image_prediction.exceptions import UnknownLabelFormat
from image_prediction.label_mapper.mappers.probability import ProbabilityMapperKeys
from image_prediction.locations import TEST_DATA_DIR
from test.utils.label import map_labels
@pytest.fixture
def expected_predictions_mapped(
    label_format, batch_of_expected_string_labels, batch_of_expected_label_to_probability_mappings
):
    """Select the expected mapped predictions for the active label format.

    Returns the string-label batch for the "index" format and the
    label-to-probability mappings for the "probability" format.

    Raises:
        UnknownLabelFormat: If ``label_format`` is neither "index" nor "probability".
    """
    if label_format == "index":
        return batch_of_expected_string_labels
    if label_format == "probability":
        return batch_of_expected_label_to_probability_mappings
    # Any other format has no corresponding expectation.
    raise UnknownLabelFormat(f"No label mapper for label format {label_format} was specified.")
@pytest.fixture
def expected_predictions(label_format, batch_of_expected_numeric_labels, batch_of_expected_probability_arrays):
    """Select the expected unmapped predictions for the active label format.

    Returns the numeric-label batch for the "index" format and the
    probability arrays for the "probability" format.

    Raises:
        UnknownLabelFormat: If ``label_format`` is neither "index" nor "probability".
    """
    if label_format == "index":
        return batch_of_expected_numeric_labels
    if label_format == "probability":
        return batch_of_expected_probability_arrays
    # Any other format has no corresponding expectation.
    raise UnknownLabelFormat(f"No label mapper for label format {label_format} was specified.")
@pytest.fixture
def batch_of_expected_string_labels(batch_of_expected_numeric_labels, classes):
    """Expected labels obtained by passing the numeric labels and ``classes`` to ``map_labels``."""
    string_labels = map_labels(batch_of_expected_numeric_labels, classes)
    return string_labels
@pytest.fixture
def batch_of_expected_numeric_labels(batch_size, classes):
    """Draw ``batch_size`` random indices into ``classes``, sampled with replacement."""
    class_indices = range(len(classes))
    return random.choices(class_indices, k=batch_size)
@pytest.fixture
def batch_of_expected_label_to_probability_mappings(batch_of_expected_probability_arrays, classes):
    """Build one expected label/probabilities mapping per probability array.

    Each entry holds the most probable label under ``ProbabilityMapperKeys.LABEL``
    and, under ``ProbabilityMapperKeys.PROBABILITIES``, a dict from class to
    probability (rounded to 4 decimals) ordered by descending probability.
    """
    # Round to four decimals first, then cast the result to a builtin float.
    to_rounded_float = rcompose(partial(np.round, decimals=4), float)

    def build_mapping(probability_array):
        rounded = map(to_rounded_float, probability_array)
        ranked = sorted(zip(classes, rounded), key=itemgetter(1), reverse=True)
        # The first ranked pair carries the highest probability; its label is
        # also the first key of the dict built from the ranking.
        top_label = ranked[0][0]
        return {
            ProbabilityMapperKeys.LABEL: top_label,
            ProbabilityMapperKeys.PROBABILITIES: dict(ranked),
        }

    return [build_mapping(probability_array) for probability_array in batch_of_expected_probability_arrays]
@pytest.fixture
def batch_of_expected_probability_arrays(batch_size, classes):
    """Generate ``batch_size`` arrays holding one uniform random value per class.

    The values are independent draws from ``np.random.uniform`` and are not
    normalised to sum to one.
    """
    num_classes = len(classes)
    return [np.random.uniform(size=num_classes) for _ in range(batch_size)]
@pytest.fixture
def output_batch_generator(expected_predictions):
    """Provide an iterator over the expected predictions."""
    prediction_iterator = iter(expected_predictions)
    return prediction_iterator
@pytest.fixture
def metadata_plus_mapped_prediction(expected_predictions_mapped, metadata):
    """Pair each metadata dict with its mapped prediction under "classification".

    Predictions and metadata are matched positionally via ``zip``, so the
    result is as long as the shorter of the two inputs.
    """
    combined = []
    for prediction, meta in zip(expected_predictions_mapped, metadata):
        # Metadata entries are unpacked after the prediction, so they win on a key clash.
        combined.append({"classification": prediction, **meta})
    return combined
@pytest.fixture
def metadata_formatted_plus_mapped_prediction_formatted(expected_predictions_mapped_and_formatted, metadata_formatted):
    """Pair each formatted metadata dict with its formatted prediction under "classification".

    Predictions and metadata are matched positionally via ``zip``, so the
    result is as long as the shorter of the two inputs.
    """
    combined = []
    for prediction, meta in zip(expected_predictions_mapped_and_formatted, metadata_formatted):
        # Metadata entries are unpacked after the prediction, so they win on a key clash.
        combined.append({"classification": prediction, **meta})
    return combined
@pytest.fixture
def expected_predictions_mapped_and_formatted(expected_predictions_mapped):
    """Expected mapped predictions with every dict key replaced by its ``.value``.

    NOTE(review): each prediction must be a dict whose keys expose ``.value``
    (presumably ``ProbabilityMapperKeys`` members) - confirm this fixture is
    only used with the "probability" label format.
    """

    def with_value_keys(mapped_prediction):
        return {key.value: val for key, val in mapped_prediction.items()}

    return list(map(with_value_keys, expected_predictions_mapped))
@pytest.fixture
def real_expected_service_response():
    """Load the recorded predictions JSON from ``TEST_DATA_DIR``.

    Returns:
        The parsed JSON content of the stored predictions file.
    """
    response_path = os.path.join(TEST_DATA_DIR, "f2dc689ca794fccb8cd38b95f2bf6ba9_predictions.json")
    # Decode explicitly as UTF-8 so parsing does not depend on the platform's
    # locale encoding.
    with open(response_path, "r", encoding="utf-8") as f:
        response = json.load(f)
    # json.load has consumed the whole file, so the handle is closed before the
    # value is handed to the test instead of staying open until teardown.
    return response