diff --git a/image_prediction/transformer/transformer.py b/image_prediction/transformer/transformer.py index c079c68..c23d42f 100644 --- a/image_prediction/transformer/transformer.py +++ b/image_prediction/transformer/transformer.py @@ -1,4 +1,7 @@ import abc +from typing import Iterable + +from funcy import curry, identity class Transformer(abc.ABC): @@ -8,3 +11,10 @@ class Transformer(abc.ABC): def __call__(self, obj): return self.transform(obj) + + @staticmethod + def _must_be_mapped_over(obj): + return isinstance(obj, Iterable) and not isinstance(obj, dict) + + def _apply(self, func, obj): + return (curry(map) if self._must_be_mapped_over(obj) else identity)(func)(obj) diff --git a/image_prediction/transformer/transformers/coordinate/coordinate_transformer.py b/image_prediction/transformer/transformers/coordinate/coordinate_transformer.py index d208d87..d72fc2a 100644 --- a/image_prediction/transformer/transformers/coordinate/coordinate_transformer.py +++ b/image_prediction/transformer/transformers/coordinate/coordinate_transformer.py @@ -13,20 +13,10 @@ class CoordinateTransformer(Transformer): raise NotImplementedError def forward(self, metadata): - try: - return self._forward(metadata) - # FIXME: test case not missing?? why is it missing for backward? - except TypeError: - # FIXME: refactor tests so this is not necessary - return list(map(self._forward, metadata)) + return self._apply(self._forward, metadata) def backward(self, metadata): - try: - return self._backward(metadata) - # FIXME: test case missing - except TypeError: - # FIXME: refactor tests so this is not necessary - return list(map(self._backward, metadata)) + return self._apply(self._backward, metadata) def transform(self, metadata): return self.forward(metadata) diff --git a/test/unit_tests/coordinate_transformer_test.py b/test/unit_tests/coordinate_transformer_test.py index d566d81..8215fe3 100644 --- a/test/unit_tests/coordinate_transformer_test.py +++ b/test/unit_tests/coordinate_transformer_test.py @@ -3,7 +3,7 @@ from operator import itemgetter, attrgetter import numpy as np import pytest from fpdf import fpdf -from funcy import compose, omit +from funcy import compose from pdf2image import pdf2image from image_prediction.estimator.preprocessor.utils import image_to_normalized_tensor @@ -53,11 +53,11 @@ def test_pdfnet_coordinate_transformer(position_metadata_in_given_system, positi def test_coordinate_transformer_by_metadata( transformer, position_metadata_in_given_system, position_metadata_in_reference_system ): - assert transformer.forward(position_metadata_in_reference_system) == position_metadata_in_given_system - assert transformer.backward(position_metadata_in_given_system) == position_metadata_in_reference_system - assert ( - compose(transformer.backward, transformer.forward)(position_metadata_in_reference_system) - == position_metadata_in_reference_system + assert equal(transformer.forward(position_metadata_in_reference_system), position_metadata_in_given_system) + assert equal(transformer.backward(position_metadata_in_given_system), position_metadata_in_reference_system) + assert equal( + compose(transformer.backward, transformer.forward)(position_metadata_in_reference_system), + position_metadata_in_reference_system ) @@ -77,13 +77,19 @@ def transformer(coordinate_system): @pytest.fixture -def position_metadata_in_given_system(corner, corner2metadata_in_given_system): - return corner2metadata_in_given_system[corner] +def position_metadata_in_given_system(corner, corner2metadata_in_given_system, multiple): + metadata = corner2metadata_in_given_system[corner] + return [metadata, metadata] if multiple else metadata @pytest.fixture -def position_metadata_in_reference_system(corner, corner2metadata_in_reference_system): - return corner2metadata_in_reference_system[corner] +def position_metadata_in_reference_system(corner, corner2metadata_in_reference_system, multiple): + metadata = corner2metadata_in_reference_system[corner] + return [metadata, metadata] if multiple else metadata + + +def equal(a, b): + return (list(a) if isinstance(a, map) else a) == b @pytest.fixture(params=["top_left", "bottom_left", "bottom_right", "top_right"]) @@ -91,11 +97,6 @@ def corner(request): return request.param -@pytest.fixture -def corner2metadata_in_reference_system(get_fpdf_corner_metadat): - return get_fpdf_corner_metadat - - @pytest.fixture def corner2metadata_in_given_system( coordinate_system, get_fpdf_corner_metadat, get_fitz_corner_metadat, get_pdfnet_corner_metadata @@ -113,6 +114,16 @@ def corner2metadata_in_given_system( raise ValueError(f"Unknown coordinate system: {coordinate_system}") +@pytest.fixture +def corner2metadata_in_reference_system(get_fpdf_corner_metadat): + return get_fpdf_corner_metadat + + +@pytest.fixture(params=[True, False]) +def multiple(request): + return request.param + + @pytest.fixture def get_fpdf_corner_metadat(base_position_metadata, get_metadata_for_coords, get_image_and_page_edge_lengths): """Origin top left, y1 <= y2; all coords on page are positive @@ -124,6 +135,7 @@ def get_fpdf_corner_metadat(base_position_metadata, get_metadata_for_coords, get |////| |////| +--(1,3) +--(3,3) """ + # noinspection PyTupleAssignmentBalance width, height, page_width, page_height = get_image_and_page_edge_lengths() return { @@ -150,6 +162,7 @@ def get_pdfnet_corner_metadata(base_position_metadata, get_metadata_for_coords, |////| |////| (0,0)--+ (2,0)--+ """ + # noinspection PyTupleAssignmentBalance width, height, page_width, page_height = get_image_and_page_edge_lengths() return { @@ -169,7 +182,7 @@ def base_position_metadata(width, height, page_width, page_height): def get_metadata_for_coords(base_position_metadata): def __get_metadata_for_coords(*coords): meta_data_coords = get_metadata_coords(*coords) - return {**meta_data_coords, **omit(base_position_metadata, meta_data_coords.keys())} + return {**meta_data_coords, **base_position_metadata} return __get_metadata_for_coords @@ -204,10 +217,6 @@ def page_width(request): def get_base_position_metadata(width, height, page_width, page_height): return { - Info.X1: None, - Info.Y1: None, - Info.X2: None, - Info.Y2: None, Info.WIDTH: width, Info.HEIGHT: height, Info.PAGE_IDX: 0, @@ -221,18 +230,19 @@ def get_metadata_coords(x1, y1, x2, y2): @pytest.mark.parametrize("coordinate_system", ["pdfnet"]) +@pytest.mark.parametrize("multiple", [False]) def test_coordinate_transformer_by_image( transformer, position_metadata_in_given_system, position_metadata_in_reference_system ): metadata_transformed = transformer(position_metadata_in_given_system) - target_image = metadata_to_test_image(position_metadata_in_reference_system) - test_image = metadata_to_test_image(metadata_transformed) + target_image = metadata_to_test_page_image(position_metadata_in_reference_system) + test_image = metadata_to_test_page_image(metadata_transformed) assert np.allclose(target_image, test_image) -def metadata_to_test_image(metadata): +def metadata_to_test_page_image(metadata): image = get_coordinate_test_image( *itemgetter(*attrgetter("WIDTH", "HEIGHT")(Info))(metadata) )