refactoring
This commit is contained in:
parent
38869d52c6
commit
692e72b3b2
@ -10,7 +10,7 @@ class Transformer(abc.ABC):
|
||||
raise NotImplementedError
|
||||
|
||||
def __call__(self, obj):
|
||||
return self.transform(obj)
|
||||
return self._apply(self.transform, obj)
|
||||
|
||||
@staticmethod
|
||||
def _must_be_mapped_over(obj):
|
||||
|
||||
@ -8,14 +8,10 @@ from image_prediction.utils import get_logger
|
||||
logger = get_logger()
|
||||
|
||||
|
||||
# TODO: move to transformers
|
||||
class ResponseTransformer(Transformer):
|
||||
def transform(self, data):
|
||||
logger.debug("ResponseTransformer.transform")
|
||||
try:
|
||||
return build_image_info(data)
|
||||
except TypeError:
|
||||
return map(build_image_info, data)
|
||||
return build_image_info(data)
|
||||
|
||||
|
||||
def build_image_info(data: dict) -> dict:
|
||||
|
||||
@ -444,3 +444,7 @@ def real_expected_service_response():
|
||||
def pipeline():
|
||||
pipeline = load_pipeline(verbose=False)
|
||||
return pipeline
|
||||
|
||||
|
||||
def transform_equal(a, b):
|
||||
return (list(a) if isinstance(a, map) else a) == b
|
||||
@ -4,16 +4,17 @@ from image_prediction.compositor.compositor import TransformerCompositor
|
||||
from image_prediction.formatter.formatters.camel_case import Snake2CamelCaseKeyFormatter
|
||||
from image_prediction.formatter.formatters.enum import EnumFormatter
|
||||
from image_prediction.formatter.formatters.identity import IdentityFormatter
|
||||
from test.conftest import transform_equal
|
||||
|
||||
|
||||
def test_identity(metadata):
|
||||
compositor = TransformerCompositor(IdentityFormatter())
|
||||
assert compositor(metadata) == metadata
|
||||
assert transform_equal(compositor(metadata), metadata)
|
||||
|
||||
|
||||
def test_composition(metadata, metadata_formatted):
|
||||
compositor = TransformerCompositor(IdentityFormatter(), EnumFormatter())
|
||||
assert metadata_formatted == list(compositor(metadata))
|
||||
assert transform_equal(compositor(metadata), metadata_formatted)
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
@ -28,4 +29,4 @@ def compositor_test_camel_case_metadata(info_label_map):
|
||||
|
||||
def test_enum_to_camel_case(compositor_test_enum_metadata, compositor_test_camel_case_metadata):
|
||||
compositor = TransformerCompositor(EnumFormatter(), Snake2CamelCaseKeyFormatter())
|
||||
assert list(compositor(compositor_test_enum_metadata)) == compositor_test_camel_case_metadata
|
||||
assert transform_equal(compositor(compositor_test_enum_metadata), compositor_test_camel_case_metadata)
|
||||
|
||||
@ -12,7 +12,7 @@ from image_prediction.info import Info
|
||||
from image_prediction.transformer.transformers.coordinate.fitz import FitzCoordinateTransformer
|
||||
from image_prediction.transformer.transformers.coordinate.fpdf import FPDFCoordinateTransformer
|
||||
from image_prediction.transformer.transformers.coordinate.pdfnet import PDFNetCoordinateTransformer
|
||||
from test.conftest import array_to_image, add_image
|
||||
from test.conftest import array_to_image, add_image, transform_equal
|
||||
|
||||
|
||||
@pytest.mark.parametrize("coordinate_system", ["fpdf"])
|
||||
@ -53,9 +53,9 @@ def test_pdfnet_coordinate_transformer(position_metadata_in_given_system, positi
|
||||
def test_coordinate_transformer_by_metadata(
|
||||
transformer, position_metadata_in_given_system, position_metadata_in_reference_system
|
||||
):
|
||||
assert equal(transformer.forward(position_metadata_in_reference_system), position_metadata_in_given_system)
|
||||
assert equal(transformer.backward(position_metadata_in_given_system), position_metadata_in_reference_system)
|
||||
assert equal(
|
||||
assert transform_equal(transformer.forward(position_metadata_in_reference_system), position_metadata_in_given_system)
|
||||
assert transform_equal(transformer.backward(position_metadata_in_given_system), position_metadata_in_reference_system)
|
||||
assert transform_equal(
|
||||
compose(transformer.backward, transformer.forward)(position_metadata_in_reference_system),
|
||||
position_metadata_in_reference_system
|
||||
)
|
||||
@ -88,10 +88,6 @@ def position_metadata_in_reference_system(corner, corner2metadata_in_reference_s
|
||||
return [metadata, metadata] if multiple else metadata
|
||||
|
||||
|
||||
def equal(a, b):
|
||||
return (list(a) if isinstance(a, map) else a) == b
|
||||
|
||||
|
||||
@pytest.fixture(params=["top_left", "bottom_left", "bottom_right", "top_right"])
|
||||
def corner(request):
|
||||
return request.param
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user