# Matthias Bisping 86f2abc553 renaming
# 2022-03-28 18:52:39 +02:00
#
# 204 lines
# 5.5 KiB
# Python

import random
import tempfile
from itertools import starmap
from operator import itemgetter
import fpdf
import numpy as np
import pytest
from PIL import Image
from image_prediction.classifier.classifier import Classifier
from image_prediction.classifier.image_classifier import ImageClassifier
from image_prediction.estimator.adapter.adapters.keras import KerasEstimatorAdapter
from image_prediction.estimator.adapter.adapters.mock import EstimatorMock, EstimatorAdapterMock
from image_prediction.exceptions import UnknownEstimatorAdapter, UnknownImageExtractor
from image_prediction.image_extractor.extractor import ImageMetadataPair
from image_prediction.image_extractor.extractors.mock import ImageExtractorMock
from image_prediction.image_extractor.extractors.parsable import ParsablePDFImageExtractor
@pytest.fixture
def image_extractor(extractor_type):
    """Instantiate the image extractor named by the ``extractor_type`` fixture.

    Supported names: ``"mock"`` -> :class:`ImageExtractorMock`,
    ``"parsable_pdf"`` -> :class:`ParsablePDFImageExtractor`.

    Raises:
        UnknownImageExtractor: for any other ``extractor_type``.
    """
    known_extractors = (
        ("mock", ImageExtractorMock),
        ("parsable_pdf", ParsablePDFImageExtractor),
    )
    for type_name, extractor_cls in known_extractors:
        if extractor_type == type_name:
            return extractor_cls()
    raise UnknownImageExtractor(f"No image extractor for type {extractor_type} was specified.")
@pytest.fixture
def image_classifier(classifier, monkeypatch, expected_predictions):
    """Wrap the ``classifier`` fixture in an :class:`ImageClassifier`.

    ``monkeypatch`` and ``expected_predictions`` are requested but not used in
    this body; presumably they are listed so pytest sets them up for the same
    test — TODO confirm whether they can be dropped.
    """
    wrapped = ImageClassifier(classifier)
    return wrapped
@pytest.fixture
def classifier(estimator_adapter, classes):
    """Return a :class:`Classifier` built from ``estimator_adapter`` and the label list ``classes``."""
    return Classifier(estimator_adapter, classes)
@pytest.fixture
def estimator_adapter(estimator_type, keras_model, output_batch_generator, monkeypatch):
    """Build the estimator adapter selected by ``estimator_type`` and patch its ``predict``.

    ``"mock"`` wraps an :class:`EstimatorMock` in an :class:`EstimatorAdapterMock`;
    ``"keras"`` wraps the ``keras_model`` fixture in a :class:`KerasEstimatorAdapter`.

    The adapter's ``predict`` is then monkeypatched. The real implementation
    still runs on every batch, so mechanical problems surface. Each item it
    yields is replaced by the next value drawn from ``output_batch_generator``.

    Raises:
        UnknownEstimatorAdapter: if ``estimator_type`` matches neither option.
    """
    if estimator_type == "mock":
        adapter = EstimatorAdapterMock(EstimatorMock())
    elif estimator_type == "keras":
        adapter = KerasEstimatorAdapter(keras_model)
    else:
        raise UnknownEstimatorAdapter(f"No adapter for estimator type {estimator_type} was specified.")

    real_predict = adapter.predict

    def predict_with_canned_outputs(batch):
        # Exercise the genuine predict so its failure modes still show up, but
        # answer with one externally fixed prediction per item it produced, so
        # callers can be checked against the expected predictions.
        canned = []
        for _ in real_predict(batch):
            canned.append(next(output_batch_generator))
        return canned

    monkeypatch.setattr(adapter, "predict", predict_with_canned_outputs)
    return adapter
@pytest.fixture
def keras_model(input_size):
    """Build and compile a small, untrained Keras model for the ``"keras"`` adapter.

    TensorFlow and its logging/warning setup are imported inside the fixture
    body rather than at module level.

    Args:
        input_size: shape given to ``tf.keras.Input`` (per sample, no batch axis).

    Returns:
        A compiled ``tf.keras.Model``: Conv2D(3 filters, kernel 3) -> Dense(10) -> Dense(10).
    """
    import warnings
    # Process-wide side effect: ignores every DeprecationWarning from here on.
    warnings.filterwarnings("ignore", category=DeprecationWarning)
    import os
    # Set before ``import tensorflow`` — presumably so TF's native logging picks
    # it up at load time. This also changes the environment of the whole process.
    os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
    import tensorflow as tf
    # The last axis of each input sample is treated as the channel axis.
    tf.keras.backend.set_image_data_format("channels_last")
    inputs = tf.keras.Input(shape=input_size)
    conv = tf.keras.layers.Conv2D(3, 3)
    dense = tf.keras.layers.Dense(10)
    # NOTE(review): there is no Flatten/pooling between Conv2D and Dense, so the
    # Dense layers act on the last axis of the conv feature map. The model output
    # is therefore not a flat per-sample vector. estimator_adapter only uses how
    # many items predict() yields, so this presumably does not matter — confirm.
    outputs = tf.keras.layers.Dense(10)(dense(conv(inputs)))
    model = tf.keras.Model(inputs=inputs, outputs=outputs)
    # Compiled with no optimizer/loss; presumably only inference is used — confirm.
    model.compile()
    return model
@pytest.fixture
def images(input_batch):
    """Convert every sample in ``input_batch`` into a PIL image via :func:`array_to_image`."""
    return [array_to_image(sample) for sample in input_batch]
@pytest.fixture
def input_batch(batch_size, input_size):
    """Return ``batch_size`` random samples of shape ``input_size`` with values in [0, 1)."""
    batch_shape = (batch_size, *input_size)
    return np.random.random_sample(size=batch_shape)
def array_to_image(array):
    """Turn an array with values in [0, 1] into an RGB PIL image.

    Values are scaled by 255 and cast to ``uint8`` before conversion.

    Raises:
        AssertionError: if any value lies outside [0, 1].
    """
    assert np.all(array <= 1)
    assert np.all(array >= 0)
    scaled = array * 255
    pixels = np.uint8(scaled)
    return Image.fromarray(pixels, mode="RGB")
@pytest.fixture
def input_size(depth=3, width=10, height=15):
    """Return the per-sample input shape as ``(width, height, depth)``.

    pytest does not resolve defaulted arguments as fixtures, so this yields
    ``(10, 15, 3)``; the last entry is the channel count.
    """
    shape = (width, height, depth)
    return shape
@pytest.fixture
def expected_predictions(output_batch, classes):
    """Translate the numeric ``output_batch`` into class labels taken from ``classes``."""
    labels = map_labels(output_batch, classes)
    return labels
@pytest.fixture
def output_batch(input_batch, classes):
    """Draw one random class index (with replacement) for every sample in ``input_batch``."""
    class_indices = range(len(classes))
    n_samples = len(input_batch)
    return random.choices(class_indices, k=n_samples)
@pytest.fixture
def output_batch_generator(output_batch):
    """Return a one-shot iterator over ``output_batch``; each ``next`` hands out one prediction."""
    predictions = iter(output_batch)
    return predictions
@pytest.fixture
def classes():
    """Return the class labels used throughout these tests: ``["A", "B", "C"]``."""
    return list("ABC")
def map_labels(numeric_labels, classes):
    """Look up each numeric label in ``classes`` and return the labels as a list, in order.

    Raises:
        IndexError: if a label is not a valid index into ``classes``.
    """
    labels = []
    for index in numeric_labels:
        labels.append(classes[index])
    return labels
@pytest.fixture
def metadata(images):
    """Generate random placement metadata for each image: one dict per image, in order.

    Each image is placed at a random (x1, y1) so that it lies entirely within a
    595 x 842 page. Page indices never decrease: each image advances a shared
    counter by 0, 1 or 2, and the reported index is capped at ``len(images) - 1``.

    NOTE(review): the values depend on the order of ``random`` calls (per image:
    x1, then y1, then the page step); reordering changes results under a fixed seed.
    """
    page_idx = 0

    def current_page_idx():
        # Advance the shared counter by 0-2 pages. Only the returned value is
        # clamped; the counter itself keeps growing.
        nonlocal page_idx
        page_idx += random.randint(0, 2)
        return min(page_idx, len(images) - 1)

    def build_image_metadata(image):
        # PIL reports ``size`` as (width, height).
        width, height = image.size
        # Presumably A4 in points, matching the pdf fixture's unit="pt" — confirm.
        page_width = 595
        page_height = 842
        # Pick the (x1, y1) corner so that x2 <= page_width and y2 <= page_height.
        # NOTE(review): randint raises ValueError if the image is larger than the page.
        x1 = random.randint(0, page_width - width)
        x2 = x1 + width
        y1 = random.randint(0, page_height - height)
        y2 = y1 + height
        metadata = {
            "page_width": page_width,
            "page_height": page_height,
            "page_idx": current_page_idx(),
            "width": width,
            "height": height,
            "x1": x1,
            "x2": x2,
            "y1": y1,
            "y2": y2
        }
        return metadata

    return list(map(build_image_metadata, images))
@pytest.fixture
def image_metadata_pairs(images, metadata):
    """Pair each image with the metadata dict at the same position as an ``ImageMetadataPair``."""
    return [
        ImageMetadataPair(image, image_metadata)
        for image, image_metadata in zip(images, metadata)
    ]
@pytest.fixture
def pdf(image_metadata_pairs):
    """Render all image/metadata pairs into an FPDF document (unit: points) and return its bytes."""
    document = fpdf.FPDF(unit="pt")
    for image_metadata_pair in image_metadata_pairs:
        add_image(document, image_metadata_pair)
    return pdf_stream(document)
def add_image(pdf, image_metadata_pair):
    """Place an image on the page given by its ``page_idx`` metadata.

    Pages are appended to ``pdf`` until a page exists at zero-based index
    ``page_idx``. The image is then drawn via :func:`add_image_to_last_page`.
    If the document already has enough pages, none are added.
    """
    target_page_idx = image_metadata_pair.metadata["page_idx"]
    while fewer_pages_then_required(target_page_idx, pdf):
        pdf.add_page()
    add_image_to_last_page(pdf, image_metadata_pair)
def fewer_pages_then_required(page_idx, pdf):
    """Return True while ``pdf`` has no page at zero-based index ``page_idx`` yet.

    ``pdf.page`` is treated as the number of pages added so far, so the highest
    existing zero-based index is ``pdf.page - 1``. (The name says "then" where
    "than" is meant; it is kept unchanged for existing callers.)
    """
    last_existing_idx = pdf.page - 1
    return last_existing_idx < page_idx
def pdf_stream(pdf: "fpdf.fpdf.FPDF"):
    """Serialize ``pdf`` to raw PDF bytes.

    Works with both FPDF flavours:

    - Legacy PyFPDF returns a latin-1 ``str`` from ``output(dest="S")``, which
      is encoded to bytes.
    - fpdf2 treats ``dest`` as deprecated and returns a ``bytearray``, which has
      no ``.encode`` and is copied into ``bytes`` instead.

    The annotation is a string so ``fpdf`` is not evaluated when the function
    is defined.

    Args:
        pdf: the document to serialize.

    Returns:
        The complete PDF file content as ``bytes``.
    """
    rendered = pdf.output(dest="S")
    if isinstance(rendered, str):
        # PyFPDF 1.x builds the document as a latin-1 text string.
        return rendered.encode("latin1")
    # fpdf2 already produces binary data; normalize bytearray -> bytes.
    return bytes(rendered)
def add_image_to_last_page(pdf, image_metadata_pair):
    """Draw an image onto ``pdf``'s current page at the position in its metadata.

    ``pdf.image`` is called with a file path, so the image is first written to
    a temporary PNG file. The file lives inside a private temporary directory
    rather than being a ``NamedTemporaryFile``: reopening an open
    ``NamedTemporaryFile`` by name (as ``image.save`` does) is not supported
    on Windows.

    Args:
        pdf: target document; the image is drawn on whichever page is current.
        image_metadata_pair: ``(image, metadata)``. The metadata supplies
            ``x1``, ``y1``, ``width`` and ``height`` in the document's units.
    """
    import os

    image, metadata = image_metadata_pair
    x, y, w, h = itemgetter("x1", "y1", "width", "height")(metadata)
    with tempfile.TemporaryDirectory() as temp_dir:
        # The ".png" extension lets PIL choose the output format.
        temp_path = os.path.join(temp_dir, "image.png")
        image.save(temp_path)
        # Presumably pdf.image reads the file right away; the directory and
        # file are removed when this with-block exits.
        pdf.image(temp_path, x=x, y=y, w=w, h=h)