refactoring

This commit is contained in:
Matthias Bisping 2022-04-04 18:17:49 +02:00
parent e01b5c9acd
commit 38869d52c6
3 changed files with 45 additions and 35 deletions

View File

@ -1,4 +1,7 @@
import abc
from typing import Iterable
from funcy import curry, identity
class Transformer(abc.ABC):
@ -8,3 +11,10 @@ class Transformer(abc.ABC):
def __call__(self, obj):
return self.transform(obj)
@staticmethod
def _must_be_mapped_over(obj):
return isinstance(obj, Iterable) and not isinstance(obj, dict)
def _apply(self, func, obj):
return (curry(map) if self._must_be_mapped_over(obj) else identity)(func)(obj)

View File

@ -13,20 +13,10 @@ class CoordinateTransformer(Transformer):
raise NotImplementedError
def forward(self, metadata):
try:
return self._forward(metadata)
# FIXME: test case not missing?? why is it missing for backward?
except TypeError:
# FIXME: refactor tests so this is not necessary
return list(map(self._forward, metadata))
return self._apply(self._forward, metadata)
def backward(self, metadata):
try:
return self._backward(metadata)
# FIXME: test case missing
except TypeError:
# FIXME: refactor tests so this is not necessary
return list(map(self._backward, metadata))
return self._apply(self._backward, metadata)
def transform(self, metadata):
return self.forward(metadata)

View File

@ -3,7 +3,7 @@ from operator import itemgetter, attrgetter
import numpy as np
import pytest
from fpdf import fpdf
from funcy import compose, omit
from funcy import compose
from pdf2image import pdf2image
from image_prediction.estimator.preprocessor.utils import image_to_normalized_tensor
@ -53,11 +53,11 @@ def test_pdfnet_coordinate_transformer(position_metadata_in_given_system, positi
def test_coordinate_transformer_by_metadata(
transformer, position_metadata_in_given_system, position_metadata_in_reference_system
):
assert transformer.forward(position_metadata_in_reference_system) == position_metadata_in_given_system
assert transformer.backward(position_metadata_in_given_system) == position_metadata_in_reference_system
assert (
compose(transformer.backward, transformer.forward)(position_metadata_in_reference_system)
== position_metadata_in_reference_system
assert equal(transformer.forward(position_metadata_in_reference_system), position_metadata_in_given_system)
assert equal(transformer.backward(position_metadata_in_given_system), position_metadata_in_reference_system)
assert equal(
compose(transformer.backward, transformer.forward)(position_metadata_in_reference_system),
position_metadata_in_reference_system
)
@ -77,13 +77,19 @@ def transformer(coordinate_system):
@pytest.fixture
def position_metadata_in_given_system(corner, corner2metadata_in_given_system):
return corner2metadata_in_given_system[corner]
def position_metadata_in_given_system(corner, corner2metadata_in_given_system, multiple):
metadata = corner2metadata_in_given_system[corner]
return [metadata, metadata] if multiple else metadata
@pytest.fixture
def position_metadata_in_reference_system(corner, corner2metadata_in_reference_system):
return corner2metadata_in_reference_system[corner]
def position_metadata_in_reference_system(corner, corner2metadata_in_reference_system, multiple):
metadata = corner2metadata_in_reference_system[corner]
return [metadata, metadata] if multiple else metadata
def equal(a, b):
return (list(a) if isinstance(a, map) else a) == b
@pytest.fixture(params=["top_left", "bottom_left", "bottom_right", "top_right"])
@ -91,11 +97,6 @@ def corner(request):
return request.param
@pytest.fixture
def corner2metadata_in_reference_system(get_fpdf_corner_metadat):
return get_fpdf_corner_metadat
@pytest.fixture
def corner2metadata_in_given_system(
coordinate_system, get_fpdf_corner_metadat, get_fitz_corner_metadat, get_pdfnet_corner_metadata
@ -113,6 +114,16 @@ def corner2metadata_in_given_system(
raise ValueError(f"Unknown coordinate system: {coordinate_system}")
@pytest.fixture
def corner2metadata_in_reference_system(get_fpdf_corner_metadat):
return get_fpdf_corner_metadat
@pytest.fixture(params=[True, False])
def multiple(request):
return request.param
@pytest.fixture
def get_fpdf_corner_metadat(base_position_metadata, get_metadata_for_coords, get_image_and_page_edge_lengths):
"""Origin top left, y1 <= y2; all coords on page are positive
@ -124,6 +135,7 @@ def get_fpdf_corner_metadat(base_position_metadata, get_metadata_for_coords, get
|////| |////|
+--(1,3) +--(3,3)
"""
# noinspection PyTupleAssignmentBalance
width, height, page_width, page_height = get_image_and_page_edge_lengths()
return {
@ -150,6 +162,7 @@ def get_pdfnet_corner_metadata(base_position_metadata, get_metadata_for_coords,
|////| |////|
(0,0)--+ (2,0)--+
"""
# noinspection PyTupleAssignmentBalance
width, height, page_width, page_height = get_image_and_page_edge_lengths()
return {
@ -169,7 +182,7 @@ def base_position_metadata(width, height, page_width, page_height):
def get_metadata_for_coords(base_position_metadata):
def __get_metadata_for_coords(*coords):
meta_data_coords = get_metadata_coords(*coords)
return {**meta_data_coords, **omit(base_position_metadata, meta_data_coords.keys())}
return {**meta_data_coords, **base_position_metadata}
return __get_metadata_for_coords
@ -204,10 +217,6 @@ def page_width(request):
def get_base_position_metadata(width, height, page_width, page_height):
return {
Info.X1: None,
Info.Y1: None,
Info.X2: None,
Info.Y2: None,
Info.WIDTH: width,
Info.HEIGHT: height,
Info.PAGE_IDX: 0,
@ -221,18 +230,19 @@ def get_metadata_coords(x1, y1, x2, y2):
@pytest.mark.parametrize("coordinate_system", ["pdfnet"])
@pytest.mark.parametrize("multiple", [False])
def test_coordinate_transformer_by_image(
transformer, position_metadata_in_given_system, position_metadata_in_reference_system
):
metadata_transformed = transformer(position_metadata_in_given_system)
target_image = metadata_to_test_image(position_metadata_in_reference_system)
test_image = metadata_to_test_image(metadata_transformed)
target_image = metadata_to_test_page_image(position_metadata_in_reference_system)
test_image = metadata_to_test_page_image(metadata_transformed)
assert np.allclose(target_image, test_image)
def metadata_to_test_image(metadata):
def metadata_to_test_page_image(metadata):
image = get_coordinate_test_image(
*itemgetter(*attrgetter("WIDTH", "HEIGHT")(Info))(metadata)
)