2022-04-14 18:36:13 +02:00

224 lines
6.4 KiB
Python

import json
import logging
import os
import random
from functools import partial
from itertools import starmap
from operator import itemgetter
import fpdf
import numpy as np
import pytest
from funcy import rcompose, merge
from image_prediction.exceptions import (
UnknownLabelFormat,
)
from image_prediction.image_extractor.extractor import ImageMetadataPair
from image_prediction.info import Info
from image_prediction.label_mapper.mappers.probability import ProbabilityMapperKeys
from image_prediction.locations import TEST_DATA_DIR
from image_prediction.pipeline import load_pipeline
from image_prediction.utils import get_logger
from test.utils.generation.pdf import add_image, pdf_stream
from test.utils.label import map_labels
# Fixture modules that pytest loads as plugins for this test suite.
pytest_plugins = [
    "test.fixtures.model",
    "test.fixtures.model_store",
    "test.fixtures.image",
    "test.fixtures.input",
    "test.fixtures.parameters",
    "test.fixtures.label",
]
@pytest.fixture(autouse=True)
def mute_logger():
    """Silence the logger returned by ``get_logger()`` around every test.

    The level is raised above ``logging.CRITICAL`` before the test and set
    back to its previous value afterwards.
    """
    package_logger = get_logger()
    previous_level = package_logger.level
    # One step above CRITICAL, so that even critical records are dropped.
    silent_level = logging.CRITICAL + 1
    package_logger.setLevel(silent_level)
    yield
    package_logger.setLevel(previous_level)
@pytest.fixture
def expected_predictions_mapped(
    label_format, batch_of_expected_string_labels, batch_of_expected_label_to_probability_mappings
):
    """Select the expected mapped predictions that match ``label_format``.

    Returns the string labels for ``"index"`` and the label-to-probability
    mappings for ``"probability"``.

    Raises:
        UnknownLabelFormat: if ``label_format`` is neither of the two.
    """
    if label_format == "index":
        return batch_of_expected_string_labels
    if label_format == "probability":
        return batch_of_expected_label_to_probability_mappings
    raise UnknownLabelFormat(f"No label mapper for label format {label_format} was specified.")
@pytest.fixture
def expected_predictions(label_format, batch_of_expected_numeric_labels, batch_of_expected_probability_arrays):
    """Select the expected raw predictions that match ``label_format``.

    Returns the numeric labels for ``"index"`` and the probability arrays for
    ``"probability"``.

    Raises:
        UnknownLabelFormat: if ``label_format`` is neither of the two.
    """
    if label_format == "index":
        return batch_of_expected_numeric_labels
    if label_format == "probability":
        return batch_of_expected_probability_arrays
    raise UnknownLabelFormat(f"No label mapper for label format {label_format} was specified.")
@pytest.fixture
def input_batch(batch_size, input_size):
    """Return a random array of shape ``(batch_size, *input_size)``."""
    batch_shape = (batch_size, *input_size)
    return np.random.random_sample(size=batch_shape)
@pytest.fixture(params=[0, 1, 2, 16, 32])
def batch_size(request):
    """Parametrize dependent tests over the batch sizes 0, 1, 2, 16 and 32."""
    size = request.param
    return size
@pytest.fixture
def batch_of_expected_string_labels(batch_of_expected_numeric_labels, classes):
    """Return the expected numeric labels passed through ``map_labels`` with ``classes``."""
    string_labels = map_labels(batch_of_expected_numeric_labels, classes)
    return string_labels
@pytest.fixture
def batch_of_expected_numeric_labels(batch_size, classes):
    """Draw ``batch_size`` random class indices (with replacement) from ``classes``."""
    class_indices = range(len(classes))
    return random.choices(class_indices, k=batch_size)
@pytest.fixture
def batch_of_expected_label_to_probability_mappings(batch_of_expected_probability_arrays, classes):
    """Build the expected label-to-probability mapping for every probability array.

    Each probability is rounded to 4 decimals and converted to ``float``. The
    per-class probabilities are ordered from most to least likely, and the
    first label of that ordering is reported as the predicted label.
    """

    def round_probability(probability):
        return float(np.round(probability, decimals=4))

    def to_mapping(probabilities):
        rounded = map(round_probability, probabilities)
        ranked = sorted(zip(classes, rounded), key=itemgetter(1), reverse=True)
        label_to_probability = dict(ranked)
        # The first key of the dict is the highest-ranked label.
        top_label = list(label_to_probability)[0]
        return {
            ProbabilityMapperKeys.LABEL: top_label,
            ProbabilityMapperKeys.PROBABILITIES: label_to_probability,
        }

    return [to_mapping(probabilities) for probabilities in batch_of_expected_probability_arrays]
@pytest.fixture
def batch_of_expected_probability_arrays(batch_size, classes):
    """Return ``batch_size`` random arrays, each holding one uniform sample per class."""
    n_classes = len(classes)
    return [np.random.uniform(size=n_classes) for _ in range(batch_size)]
@pytest.fixture
def output_batch_generator(expected_predictions):
    """Return an iterator over ``expected_predictions``."""
    prediction_iterator = iter(expected_predictions)
    return prediction_iterator
@pytest.fixture
def metadata_plus_mapped_prediction(expected_predictions_mapped, metadata):
    """Pair each mapped prediction with its metadata in a single dict.

    The prediction is stored under ``"classification"``; the metadata entries
    are added after it.
    """

    def attach_prediction(prediction, image_metadata):
        return {"classification": prediction, **image_metadata}

    return list(starmap(attach_prediction, zip(expected_predictions_mapped, metadata)))
@pytest.fixture
def metadata_formatted_plus_mapped_prediction_formatted(expected_predictions_mapped_and_formatted, metadata_formatted):
    """Pair each formatted prediction with its formatted metadata in a single dict.

    The prediction is stored under ``"classification"``; the metadata entries
    are added after it.
    """

    def attach_prediction(prediction, image_metadata):
        return {"classification": prediction, **image_metadata}

    paired = zip(expected_predictions_mapped_and_formatted, metadata_formatted)
    return list(starmap(attach_prediction, paired))
@pytest.fixture
def expected_predictions_mapped_and_formatted(expected_predictions_mapped):
    """Replace the keys of every mapped prediction by their ``.value``."""

    def format_keys(mapped_prediction):
        return {key.value: value for key, value in mapped_prediction.items()}

    return list(map(format_keys, expected_predictions_mapped))
@pytest.fixture
def metadata(images, info_label_map):
    """Generate random placement metadata for every image in ``images``.

    Each image is placed at a random position on a 595 x 842 page
    (presumably A4 in points; confirm against the PDF generation helpers).
    Page indices never decrease from one image to the next and are capped at
    ``len(images) - 1``.
    """
    page_width = 595
    page_height = 842
    page_cursor = 0

    def next_page_idx():
        # Advance by 0 to 3 pages, but never beyond the last admissible index.
        nonlocal page_cursor
        page_cursor += random.randint(0, 3)
        return min(page_cursor, len(images) - 1)

    def describe(image):
        width, height = image.size
        # Draw the upper-left corner so that the image fits on the page.
        left = random.randint(0, page_width - width)
        top = random.randint(0, page_height - height)
        page_idx = next_page_idx()
        return {
            info_label_map.PAGE_WIDTH: page_width,
            info_label_map.PAGE_HEIGHT: page_height,
            info_label_map.PAGE_IDX: page_idx,
            info_label_map.WIDTH: width,
            info_label_map.HEIGHT: height,
            info_label_map.X1: left,
            info_label_map.X2: left + width,
            info_label_map.Y1: top,
            info_label_map.Y2: top + height,
            info_label_map.ALPHA: image.mode == "RGBA",
        }

    return [describe(image) for image in images]
@pytest.fixture
def info_label_map():
    """Return the ``Info`` class used to key image metadata."""
    label_map = Info
    return label_map
@pytest.fixture
def metadata_formatted(metadata):
    """Replace the keys of every metadata dict by their ``.value``."""
    return [
        {key.value: val for key, val in image_metadata.items()}
        for image_metadata in metadata
    ]
@pytest.fixture
def image_metadata_pairs(images, metadata):
    """Combine images and their metadata pairwise into ``ImageMetadataPair`` objects."""
    return [
        ImageMetadataPair(image, image_metadata)
        for image, image_metadata in zip(images, metadata)
    ]
@pytest.fixture
def pdf(image_metadata_pairs):
    """Build a PDF (unit: points) from the image/metadata pairs.

    Every pair is added through ``add_image``; the result of ``pdf_stream``
    for the finished document is returned.
    """
    document = fpdf.FPDF(unit="pt")
    for image_metadata_pair in image_metadata_pairs:
        add_image(document, image_metadata_pair)
    return pdf_stream(document)
@pytest.fixture
def real_pdf():
    """Return the raw bytes of the sample PDF stored in ``TEST_DATA_DIR``.

    The content is read completely and the file is closed before the value is
    handed to the test, so no file handle stays open while the test runs.
    """
    pdf_path = os.path.join(TEST_DATA_DIR, "f2dc689ca794fccb8cd38b95f2bf6ba9.pdf")
    with open(pdf_path, "rb") as f:
        return f.read()
@pytest.fixture
def real_expected_service_response():
    """Return the parsed expected predictions for the sample PDF.

    The JSON file is decoded explicitly as UTF-8 so that the result does not
    depend on the platform's default encoding. The file is closed before the
    value is handed to the test.
    """
    predictions_path = os.path.join(TEST_DATA_DIR, "f2dc689ca794fccb8cd38b95f2bf6ba9_predictions.json")
    with open(predictions_path, "r", encoding="utf-8") as f:
        return json.load(f)
@pytest.fixture
def pipeline():
    """Return the pipeline produced by ``load_pipeline`` with ``verbose=False``."""
    return load_pipeline(verbose=False)
def get_base_position_metadata(width, height, page_width, page_height):
    """Return the size-related metadata of an image on the first page.

    The result is keyed by ``Info`` members and holds the image dimensions,
    page index 0 and the page dimensions.
    """
    image_part = {Info.WIDTH: width, Info.HEIGHT: height}
    page_part = {
        Info.PAGE_IDX: 0,
        Info.PAGE_WIDTH: page_width,
        Info.PAGE_HEIGHT: page_height,
    }
    return {**image_part, **page_part}
@pytest.fixture
def base_patch_metadata(width, height, page_width, page_height):
    """Return the base position metadata plus corner coordinates.

    The coordinates place the image at the origin: ``(X1, Y1)`` is ``(0, 0)``
    and ``(X2, Y2)`` is ``(width, height)``.
    """
    base = get_base_position_metadata(width, height, page_width, page_height)
    corners = {Info.X1: 0, Info.Y1: 0, Info.X2: width, Info.Y2: height}
    return {**base, **corners}