194 lines
5.2 KiB
Python
194 lines
5.2 KiB
Python
import io
|
|
import random
|
|
import tempfile
|
|
from operator import itemgetter
|
|
|
|
import fpdf
|
|
import numpy as np
|
|
import pytest
|
|
from PIL import Image
|
|
|
|
from image_prediction.classifier.classifier import Classifier
|
|
from image_prediction.classifier.image_classifier import ImageClassifier
|
|
from image_prediction.estimator.adapter.adapters.keras import KerasEstimatorAdapter
|
|
from image_prediction.estimator.adapter.adapters.mock import EstimatorMock, EstimatorAdapterMock
|
|
from image_prediction.exceptions import UnknownEstimatorAdapter, UnknownImageExtractor
|
|
from image_prediction.image_extractor.extractors.mock import ImageExtractorMock
|
|
from image_prediction.image_extractor.extractors.parsable import ParsablePDFImageExtractor
|
|
|
|
|
|
@pytest.fixture
def image_extractor(extractor_type):
    """Instantiate the image extractor selected by ``extractor_type``.

    Supported names are ``"mock"`` and ``"parsable_pdf"``.

    Raises:
        UnknownImageExtractor: for any other ``extractor_type``.
    """
    # Candidates are compared in order with ``==``, first match wins.
    known_extractors = (
        ("mock", ImageExtractorMock),
        ("parsable_pdf", ParsablePDFImageExtractor),
    )
    for name, extractor_cls in known_extractors:
        if extractor_type == name:
            return extractor_cls()
    raise UnknownImageExtractor(f"No image extractor for type {extractor_type} was specified.")
|
|
|
|
|
|
@pytest.fixture
def image_classifier(classifier, monkeypatch, expected_predictions):
    """Wrap the ``classifier`` fixture in an ``ImageClassifier``.

    ``monkeypatch`` and ``expected_predictions`` are requested but not used in
    this body; requesting them only makes pytest set those fixtures up for the
    same test.
    """
    wrapped_classifier = ImageClassifier(classifier)
    return wrapped_classifier
|
|
|
|
|
|
@pytest.fixture
def classifier(estimator_adapter, classes):
    """Build a ``Classifier`` from the estimator adapter and class labels."""
    return Classifier(estimator_adapter, classes)
|
|
|
|
|
|
@pytest.fixture
def estimator_adapter(estimator_type, keras_model, output_batch_generator, monkeypatch):
    """Build the estimator adapter named by ``estimator_type`` with a patched ``predict``.

    ``"mock"`` yields an ``EstimatorAdapterMock`` around an ``EstimatorMock``;
    ``"keras"`` yields a ``KerasEstimatorAdapter`` around ``keras_model``.

    The adapter's ``predict`` is replaced (via ``monkeypatch``) by a wrapper
    that still runs the real ``predict`` on the batch, but returns one value
    taken from ``output_batch_generator`` per element the real call yields.

    Raises:
        UnknownEstimatorAdapter: for any other ``estimator_type``.
    """
    # Factories are lazy so only the selected adapter is constructed.
    adapter_factories = (
        ("mock", lambda: EstimatorAdapterMock(EstimatorMock())),
        ("keras", lambda: KerasEstimatorAdapter(keras_model)),
    )
    for adapter_type, make_adapter in adapter_factories:
        if estimator_type == adapter_type:
            adapter = make_adapter()
            break
    else:
        raise UnknownEstimatorAdapter(f"No adapter for estimator type {estimator_type} was specified.")

    # Keep a handle on the genuine predict before it is patched out.
    real_predict = adapter.predict

    def predict_with_canned_output(batch):
        # Exercise the real predict for mechanical issues, but hand callers the
        # externally defined predictions so they can be checked against
        # the expected ones.
        return [next(output_batch_generator) for _ in real_predict(batch)]

    monkeypatch.setattr(adapter, "predict", predict_with_canned_output)
    return adapter
|
|
|
|
|
|
@pytest.fixture
def keras_model(input_size):
    """Build and compile a tiny, untrained Keras model taking inputs of shape ``input_size``.

    Side effects: adds an "ignore DeprecationWarning" entry to the warnings
    filters, sets ``TF_CPP_MIN_LOG_LEVEL=3`` in ``os.environ`` and sets the
    Keras image data format to ``"channels_last"``.
    """
    import os
    import warnings

    # Both are configured before tensorflow is imported, presumably so the
    # import itself runs quietly.
    warnings.filterwarnings("ignore", category=DeprecationWarning)
    os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"

    import tensorflow as tf

    keras = tf.keras
    keras.backend.set_image_data_format("channels_last")

    inputs = keras.Input(shape=input_size)

    # Layers are created in the order conv -> hidden -> output, then applied
    # in that same order.
    conv_layer = keras.layers.Conv2D(3, 3)
    hidden_layer = keras.layers.Dense(10)
    output_layer = keras.layers.Dense(10)

    # NOTE(review): Dense acts on the last axis of the conv feature map, so the
    # output presumably keeps spatial axes rather than being a per-class
    # vector. Confirm this is intended for the tests.
    outputs = output_layer(hidden_layer(conv_layer(inputs)))

    model = keras.Model(inputs=inputs, outputs=outputs)
    model.compile()
    return model
|
|
|
|
|
|
@pytest.fixture
def images(input_batch):
    """Convert every sample of ``input_batch`` into a PIL image."""
    return [array_to_image(sample) for sample in input_batch]
|
|
|
|
|
|
@pytest.fixture
def input_batch(batch_size, input_size):
    """Random batch with values in ``[0, 1)`` and shape ``(batch_size, *input_size)``."""
    batch_shape = (batch_size, *input_size)
    return np.random.random_sample(size=batch_shape)
|
|
|
|
|
|
def array_to_image(array):
    """Convert a ``(rows, cols, 3)`` array of values in ``[0, 1]`` into an RGB PIL image.

    Values are multiplied by 255 and converted to ``uint8`` (truncating).

    Raises:
        AssertionError: if ``array`` is not 3-channel or holds values outside ``[0, 1]``.
    """
    assert np.ndim(array) == 3 and np.shape(array)[-1] == 3, "expected a (rows, cols, 3) array"
    assert np.all(array <= 1)
    assert np.all(array >= 0)

    # Pillow infers "RGB" from a 3-channel uint8 array. Passing ``mode=``
    # explicitly to ``fromarray`` is deprecated since Pillow 11.3.
    return Image.fromarray(np.uint8(array * 255))
|
|
|
|
|
|
@pytest.fixture
def input_size(depth=3, width=10, height=15):
    """Shape of one input sample as the tuple ``(width, height, depth)``.

    The tuple is unpacked into a numpy shape downstream, where its entries act
    as (rows, cols, channels).
    """
    sample_shape = (width, height, depth)
    return sample_shape
|
|
|
|
|
|
@pytest.fixture
def expected_predictions(output_batch, classes):
    """Class labels corresponding to the numeric predictions in ``output_batch``."""
    return map_labels(numeric_labels=output_batch, classes=classes)
|
|
|
|
|
|
@pytest.fixture
def output_batch(input_batch, classes):
    """One random class index per sample in ``input_batch`` (drawn with replacement)."""
    class_indices = range(len(classes))
    sample_count = len(input_batch)
    return random.choices(class_indices, k=sample_count)
|
|
|
|
|
|
@pytest.fixture
def output_batch_generator(output_batch):
    """Iterator yielding the entries of ``output_batch`` one at a time."""
    label_iterator = iter(output_batch)
    return label_iterator
|
|
|
|
|
|
@pytest.fixture
def classes():
    """Class labels used throughout the tests: ``["A", "B", "C"]``."""
    return list("ABC")
|
|
|
|
|
|
def map_labels(numeric_labels, classes):
    """Translate numeric class indices into the matching entries of ``classes``.

    Raises whatever ``classes[index]`` raises for an invalid index.
    """
    lookup = classes.__getitem__
    return list(map(lookup, numeric_labels))
|
|
|
|
|
|
@pytest.fixture
def metadata(images):
    """Random placement metadata for each image on a 595 x 842 page.

    Each entry gives the page size, a page index, the image size and a
    bounding box (``x1``/``x2``/``y1``/``y2``). The box lies inside the page as
    long as the image fits on it. Page indices never decrease: each image
    advances 0-2 pages past the previous one, capped at ``len(images) - 1``.
    """
    page_width, page_height = 595, 842
    page_cursor = 0

    def next_page_idx():
        nonlocal page_cursor
        page_cursor += random.randint(0, 2)
        return min(page_cursor, len(images) - 1)

    def describe(image):
        width, height = image.size
        # Draw order matters for reproducibility under a fixed seed:
        # x1, then y1, then the page index.
        left = random.randint(0, page_width - width)
        top = random.randint(0, page_height - height)
        return {
            "page_width": page_width,
            "page_height": page_height,
            "page_idx": next_page_idx(),
            "width": width,
            "height": height,
            "x1": left,
            "x2": left + width,
            "y1": top,
            "y2": top + height,
        }

    return [describe(image) for image in images]
|
|
|
|
|
|
@pytest.fixture
def pdf(images, metadata):
    """Render ``images`` into a PDF laid out according to ``metadata``.

    Each image is drawn at ``(x1, y1)`` with size ``(width, height)``, in points
    (``unit="pt"``), on the 0-based page ``page_idx``. Pages are appended as
    needed until that page exists.

    Returns:
        bytes: the document from ``FPDF.output(dest="S")``, latin-1 encoded.
    """
    import os

    document = fpdf.FPDF(unit="pt")
    document.add_page()

    # Use a temporary *directory* rather than NamedTemporaryFile: the image
    # has to be reopened by name (by PIL and fpdf) while the temp file would
    # still be open, which NamedTemporaryFile does not permit on Windows.
    with tempfile.TemporaryDirectory() as temp_dir:
        for idx, (image, image_metadata) in enumerate(zip(images, metadata)):
            # ``document.page`` counts pages from 1, ``page_idx`` from 0.
            while image_metadata["page_idx"] > document.page - 1:
                document.add_page()

            x, y, w, h = itemgetter("x1", "y1", "width", "height")(image_metadata)

            # One file name per image: fpdf may cache image data keyed by
            # file name (pyfpdf 1.7 does), so reusing a path could repeat
            # the first image.
            image_path = os.path.join(temp_dir, f"image_{idx}.png")
            image.save(image_path)
            document.image(image_path, x=x, y=y, w=w, h=h)

    return document.output(dest="S").encode("latin1")
|