test data generation for image stitching
This commit is contained in:
parent
2c908162f1
commit
1fd30e68b6
@ -106,7 +106,7 @@ def label_format(request):
|
||||
|
||||
@pytest.fixture
|
||||
def expected_predictions_mapped(
|
||||
label_format, batch_of_expected_string_labels, batch_of_expected_label_to_probability_mappings
|
||||
label_format, batch_of_expected_string_labels, batch_of_expected_label_to_probability_mappings
|
||||
):
|
||||
if label_format == "index":
|
||||
return batch_of_expected_string_labels
|
||||
@ -128,7 +128,7 @@ def expected_predictions(label_format, batch_of_expected_numeric_labels, batch_o
|
||||
|
||||
@pytest.fixture
|
||||
def estimator_adapter(
|
||||
estimator_type, estimator_mock, keras_model, model_handle_mock, output_batch_generator, monkeypatch
|
||||
estimator_type, estimator_mock, keras_model, model_handle_mock, output_batch_generator, monkeypatch
|
||||
):
|
||||
if estimator_type == "mock":
|
||||
estimator_adapter = EstimatorAdapter(estimator_mock)
|
||||
@ -447,4 +447,39 @@ def pipeline():
|
||||
|
||||
|
||||
def transform_equal(a, b):
|
||||
return (list(a) if isinstance(a, map) else a) == b
|
||||
return (list(a) if isinstance(a, map) else a) == b
|
||||
|
||||
|
||||
def get_base_position_metadata(width, height, page_width, page_height):
|
||||
return {
|
||||
Info.WIDTH: width,
|
||||
Info.HEIGHT: height,
|
||||
Info.PAGE_IDX: 0,
|
||||
Info.PAGE_WIDTH: page_width,
|
||||
Info.PAGE_HEIGHT: page_height,
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture(params=[33, 100])
|
||||
def height(request):
|
||||
return request.param
|
||||
|
||||
|
||||
@pytest.fixture(params=[10, 31])
|
||||
def width(request):
|
||||
return request.param
|
||||
|
||||
|
||||
@pytest.fixture(params=[220, 30])
|
||||
def page_height(request):
|
||||
return request.param
|
||||
|
||||
|
||||
@pytest.fixture(params=[100, 310])
|
||||
def page_width(request):
|
||||
return request.param
|
||||
|
||||
|
||||
def random_single_color_image_from_metadata(metadata):
|
||||
image = Image.new('RGB', (metadata[Info.WIDTH], metadata[Info.HEIGHT]), color=tuple(map(int, np.random.uniform(size=3) * 255)))
|
||||
return image
|
||||
|
||||
@ -12,7 +12,7 @@ from image_prediction.info import Info
|
||||
from image_prediction.transformer.transformers.coordinate.fitz import FitzCoordinateTransformer
|
||||
from image_prediction.transformer.transformers.coordinate.fpdf import FPDFCoordinateTransformer
|
||||
from image_prediction.transformer.transformers.coordinate.pdfnet import PDFNetCoordinateTransformer
|
||||
from test.conftest import array_to_image, add_image, transform_equal
|
||||
from test.conftest import array_to_image, add_image, transform_equal, get_base_position_metadata
|
||||
|
||||
|
||||
@pytest.mark.parametrize("coordinate_system", ["fpdf"])
|
||||
@ -53,8 +53,12 @@ def test_pdfnet_coordinate_transformer(position_metadata_in_given_system, positi
|
||||
def test_coordinate_transformer_by_metadata(
|
||||
transformer, position_metadata_in_given_system, position_metadata_in_reference_system
|
||||
):
|
||||
assert transform_equal(transformer.forward(position_metadata_in_reference_system), position_metadata_in_given_system)
|
||||
assert transform_equal(transformer.backward(position_metadata_in_given_system), position_metadata_in_reference_system)
|
||||
assert transform_equal(
|
||||
transformer.forward(position_metadata_in_reference_system), position_metadata_in_given_system
|
||||
)
|
||||
assert transform_equal(
|
||||
transformer.backward(position_metadata_in_given_system), position_metadata_in_reference_system
|
||||
)
|
||||
assert transform_equal(
|
||||
compose(transformer.backward, transformer.forward)(position_metadata_in_reference_system),
|
||||
position_metadata_in_reference_system
|
||||
@ -191,36 +195,6 @@ def get_image_and_page_edge_lengths(base_position_metadata):
|
||||
return __get_w_h_pw_ph
|
||||
|
||||
|
||||
@pytest.fixture(params=[33, 100])
|
||||
def height(request):
|
||||
return request.param
|
||||
|
||||
|
||||
@pytest.fixture(params=[10, 31])
|
||||
def width(request):
|
||||
return request.param
|
||||
|
||||
|
||||
@pytest.fixture(params=[220, 30])
|
||||
def page_height(request):
|
||||
return request.param
|
||||
|
||||
|
||||
@pytest.fixture(params=[100, 310])
|
||||
def page_width(request):
|
||||
return request.param
|
||||
|
||||
|
||||
def get_base_position_metadata(width, height, page_width, page_height):
|
||||
return {
|
||||
Info.WIDTH: width,
|
||||
Info.HEIGHT: height,
|
||||
Info.PAGE_IDX: 0,
|
||||
Info.PAGE_WIDTH: page_width,
|
||||
Info.PAGE_HEIGHT: page_height,
|
||||
}
|
||||
|
||||
|
||||
def get_metadata_coords(x1, y1, x2, y2):
|
||||
return {Info.X1: x1, Info.Y1: y1, Info.X2: x2, Info.Y2: y2}
|
||||
|
||||
|
||||
78
test/unit_tests/image_stitcher_test.py
Normal file
78
test/unit_tests/image_stitcher_test.py
Normal file
@ -0,0 +1,78 @@
|
||||
import json
|
||||
import random
|
||||
from copy import deepcopy
|
||||
from itertools import chain
|
||||
|
||||
import fpdf
|
||||
import pytest
|
||||
from funcy import juxt, merge, rpartial
|
||||
|
||||
from image_prediction.formatter.formatters.enum import EnumFormatter
|
||||
from image_prediction.image_extractor.extractor import ImageMetadataPair
|
||||
from image_prediction.info import Info
|
||||
from test.conftest import get_base_position_metadata, add_image, random_single_color_image_from_metadata
|
||||
|
||||
|
||||
def test_image_stitcher(partial_image_metadata_pairs):
|
||||
pass
|
||||
|
||||
|
||||
|
||||
@pytest.mark.parametrize("width", [100])
|
||||
@pytest.mark.parametrize("height", [300])
|
||||
@pytest.mark.parametrize("page_width", [500])
|
||||
@pytest.mark.parametrize("page_height", [500])
|
||||
def test_partial_image_metadata_pairs(patches_metadata, page_width, page_height):
|
||||
|
||||
pdf = fpdf.FPDF(unit="pt", format=(page_width, page_height))
|
||||
|
||||
for patch in patches_metadata:
|
||||
image = random_single_color_image_from_metadata(patch)
|
||||
add_image(pdf, ImageMetadataPair(image, patch))
|
||||
|
||||
pdf.output("/tmp/bla.pdf")
|
||||
|
||||
|
||||
def split_box(box, max_step=5):
|
||||
def split_recursively(box, step):
|
||||
|
||||
def split_horizontal():
|
||||
return split(Info.WIDTH, Info.X1, Info.X2)
|
||||
|
||||
def split_vertical():
|
||||
return split(Info.HEIGHT, Info.Y1, Info.Y2)
|
||||
|
||||
def split(dim, coord1, coord2):
|
||||
|
||||
if box[dim] >= 10:
|
||||
split_len = random.randint(5, box[dim] - 5)
|
||||
split_point = box[coord1] + split_len
|
||||
|
||||
box_left, box_right = juxt(deepcopy, deepcopy)(box)
|
||||
|
||||
box_left[dim] = split_len
|
||||
box_right[dim] = box[dim] - split_len
|
||||
|
||||
box_left[coord2] = split_point
|
||||
box_right[coord1] = split_point
|
||||
|
||||
return box_left, box_right
|
||||
else:
|
||||
return [box]
|
||||
|
||||
if step < max_step:
|
||||
new_boxes = random.choice([split_horizontal, split_vertical])()
|
||||
|
||||
return chain.from_iterable(map(rpartial(split_recursively, step + 1), new_boxes))
|
||||
else:
|
||||
return [box]
|
||||
|
||||
return split_recursively(box, 0)
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def patches_metadata(width, height, page_width, page_height):
|
||||
box = get_base_position_metadata(width, height, page_width, page_height)
|
||||
box = merge(box, {Info.X1: 0, Info.Y1: 0, Info.X2: width, Info.Y2: height})
|
||||
boxes = split_box(box)
|
||||
return boxes
|
||||
Loading…
x
Reference in New Issue
Block a user