refactoring
This commit is contained in:
parent
e01b5c9acd
commit
38869d52c6
@ -1,4 +1,7 @@
|
||||
import abc
|
||||
from typing import Iterable
|
||||
|
||||
from funcy import curry, identity
|
||||
|
||||
|
||||
class Transformer(abc.ABC):
|
||||
@ -8,3 +11,10 @@ class Transformer(abc.ABC):
|
||||
|
||||
def __call__(self, obj):
|
||||
return self.transform(obj)
|
||||
|
||||
@staticmethod
|
||||
def _must_be_mapped_over(obj):
|
||||
return isinstance(obj, Iterable) and not isinstance(obj, dict)
|
||||
|
||||
def _apply(self, func, obj):
|
||||
return (curry(map) if self._must_be_mapped_over(obj) else identity)(func)(obj)
|
||||
|
||||
@ -13,20 +13,10 @@ class CoordinateTransformer(Transformer):
|
||||
raise NotImplementedError
|
||||
|
||||
def forward(self, metadata):
|
||||
try:
|
||||
return self._forward(metadata)
|
||||
# FIXME: test case not missing?? why is it missing for backward?
|
||||
except TypeError:
|
||||
# FIXME: refactor tests so this is not necessary
|
||||
return list(map(self._forward, metadata))
|
||||
return self._apply(self._forward, metadata)
|
||||
|
||||
def backward(self, metadata):
|
||||
try:
|
||||
return self._backward(metadata)
|
||||
# FIXME: test case missing
|
||||
except TypeError:
|
||||
# FIXME: refactor tests so this is not necessary
|
||||
return list(map(self._backward, metadata))
|
||||
return self._apply(self._backward, metadata)
|
||||
|
||||
def transform(self, metadata):
|
||||
return self.forward(metadata)
|
||||
|
||||
@ -3,7 +3,7 @@ from operator import itemgetter, attrgetter
|
||||
import numpy as np
|
||||
import pytest
|
||||
from fpdf import fpdf
|
||||
from funcy import compose, omit
|
||||
from funcy import compose
|
||||
from pdf2image import pdf2image
|
||||
|
||||
from image_prediction.estimator.preprocessor.utils import image_to_normalized_tensor
|
||||
@ -53,11 +53,11 @@ def test_pdfnet_coordinate_transformer(position_metadata_in_given_system, positi
|
||||
def test_coordinate_transformer_by_metadata(
|
||||
transformer, position_metadata_in_given_system, position_metadata_in_reference_system
|
||||
):
|
||||
assert transformer.forward(position_metadata_in_reference_system) == position_metadata_in_given_system
|
||||
assert transformer.backward(position_metadata_in_given_system) == position_metadata_in_reference_system
|
||||
assert (
|
||||
compose(transformer.backward, transformer.forward)(position_metadata_in_reference_system)
|
||||
== position_metadata_in_reference_system
|
||||
assert equal(transformer.forward(position_metadata_in_reference_system), position_metadata_in_given_system)
|
||||
assert equal(transformer.backward(position_metadata_in_given_system), position_metadata_in_reference_system)
|
||||
assert equal(
|
||||
compose(transformer.backward, transformer.forward)(position_metadata_in_reference_system),
|
||||
position_metadata_in_reference_system
|
||||
)
|
||||
|
||||
|
||||
@ -77,13 +77,19 @@ def transformer(coordinate_system):
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def position_metadata_in_given_system(corner, corner2metadata_in_given_system):
|
||||
return corner2metadata_in_given_system[corner]
|
||||
def position_metadata_in_given_system(corner, corner2metadata_in_given_system, multiple):
|
||||
metadata = corner2metadata_in_given_system[corner]
|
||||
return [metadata, metadata] if multiple else metadata
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def position_metadata_in_reference_system(corner, corner2metadata_in_reference_system):
|
||||
return corner2metadata_in_reference_system[corner]
|
||||
def position_metadata_in_reference_system(corner, corner2metadata_in_reference_system, multiple):
|
||||
metadata = corner2metadata_in_reference_system[corner]
|
||||
return [metadata, metadata] if multiple else metadata
|
||||
|
||||
|
||||
def equal(a, b):
|
||||
return (list(a) if isinstance(a, map) else a) == b
|
||||
|
||||
|
||||
@pytest.fixture(params=["top_left", "bottom_left", "bottom_right", "top_right"])
|
||||
@ -91,11 +97,6 @@ def corner(request):
|
||||
return request.param
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def corner2metadata_in_reference_system(get_fpdf_corner_metadat):
|
||||
return get_fpdf_corner_metadat
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def corner2metadata_in_given_system(
|
||||
coordinate_system, get_fpdf_corner_metadat, get_fitz_corner_metadat, get_pdfnet_corner_metadata
|
||||
@ -113,6 +114,16 @@ def corner2metadata_in_given_system(
|
||||
raise ValueError(f"Unknown coordinate system: {coordinate_system}")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def corner2metadata_in_reference_system(get_fpdf_corner_metadat):
|
||||
return get_fpdf_corner_metadat
|
||||
|
||||
|
||||
@pytest.fixture(params=[True, False])
|
||||
def multiple(request):
|
||||
return request.param
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def get_fpdf_corner_metadat(base_position_metadata, get_metadata_for_coords, get_image_and_page_edge_lengths):
|
||||
"""Origin top left, y1 <= y2; all coords on page are positive
|
||||
@ -124,6 +135,7 @@ def get_fpdf_corner_metadat(base_position_metadata, get_metadata_for_coords, get
|
||||
|////| |////|
|
||||
+--(1,3) +--(3,3)
|
||||
"""
|
||||
# noinspection PyTupleAssignmentBalance
|
||||
width, height, page_width, page_height = get_image_and_page_edge_lengths()
|
||||
|
||||
return {
|
||||
@ -150,6 +162,7 @@ def get_pdfnet_corner_metadata(base_position_metadata, get_metadata_for_coords,
|
||||
|////| |////|
|
||||
(0,0)--+ (2,0)--+
|
||||
"""
|
||||
# noinspection PyTupleAssignmentBalance
|
||||
width, height, page_width, page_height = get_image_and_page_edge_lengths()
|
||||
|
||||
return {
|
||||
@ -169,7 +182,7 @@ def base_position_metadata(width, height, page_width, page_height):
|
||||
def get_metadata_for_coords(base_position_metadata):
|
||||
def __get_metadata_for_coords(*coords):
|
||||
meta_data_coords = get_metadata_coords(*coords)
|
||||
return {**meta_data_coords, **omit(base_position_metadata, meta_data_coords.keys())}
|
||||
return {**meta_data_coords, **base_position_metadata}
|
||||
|
||||
return __get_metadata_for_coords
|
||||
|
||||
@ -204,10 +217,6 @@ def page_width(request):
|
||||
|
||||
def get_base_position_metadata(width, height, page_width, page_height):
|
||||
return {
|
||||
Info.X1: None,
|
||||
Info.Y1: None,
|
||||
Info.X2: None,
|
||||
Info.Y2: None,
|
||||
Info.WIDTH: width,
|
||||
Info.HEIGHT: height,
|
||||
Info.PAGE_IDX: 0,
|
||||
@ -221,18 +230,19 @@ def get_metadata_coords(x1, y1, x2, y2):
|
||||
|
||||
|
||||
@pytest.mark.parametrize("coordinate_system", ["pdfnet"])
|
||||
@pytest.mark.parametrize("multiple", [False])
|
||||
def test_coordinate_transformer_by_image(
|
||||
transformer, position_metadata_in_given_system, position_metadata_in_reference_system
|
||||
):
|
||||
metadata_transformed = transformer(position_metadata_in_given_system)
|
||||
|
||||
target_image = metadata_to_test_image(position_metadata_in_reference_system)
|
||||
test_image = metadata_to_test_image(metadata_transformed)
|
||||
target_image = metadata_to_test_page_image(position_metadata_in_reference_system)
|
||||
test_image = metadata_to_test_page_image(metadata_transformed)
|
||||
|
||||
assert np.allclose(target_image, test_image)
|
||||
|
||||
|
||||
def metadata_to_test_image(metadata):
|
||||
def metadata_to_test_page_image(metadata):
|
||||
image = get_coordinate_test_image(
|
||||
*itemgetter(*attrgetter("WIDTH", "HEIGHT")(Info))(metadata)
|
||||
)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user