derived enum formatter from key formatter

This commit is contained in:
Matthias Bisping 2022-03-31 17:22:54 +02:00
parent 3eaf9dc0e1
commit e17912caa9
4 changed files with 28 additions and 9 deletions

View File

@ -9,7 +9,7 @@ class Formatter(Transformer):
raise NotImplementedError
def transform(self, obj):
return self.format(obj)
raise NotImplementedError()
def __call__(self, obj):
return self.format(obj)

View File

@ -1,13 +1,10 @@
from enum import Enum
from typing import List
from image_prediction.formatter.formatter import Formatter
from image_prediction.formatter.formatters.key_formatter import KeyFormatter
class EnumFormatter(Formatter):
class EnumFormatter(KeyFormatter):
@staticmethod
def __format(metadata: dict):
return {key.value if isinstance(key, Enum) else key: val for key, val in metadata.items()}
def format(self, metadata: List[dict]):
return map(self.__format, metadata)
def format_key(key):
return key.value if isinstance(key, Enum) else key

View File

@ -225,6 +225,23 @@ def map_labels(numeric_labels, classes):
return [classes[nl] for nl in numeric_labels]
@pytest.fixture
def metadata_plus_mapped_prediction(expected_predictions_mapped, metadata):
return [{"prediction": epm, **mdt} for epm, mdt in zip(expected_predictions_mapped, metadata)]
@pytest.fixture
def metadata_formatted_plus_mapped_prediction_formatted(expected_predictions_mapped_and_formatted, metadata_formatted):
return [
{"prediction": epm, **mdt} for epm, mdt in zip(expected_predictions_mapped_and_formatted, metadata_formatted)
]
@pytest.fixture
def expected_predictions_mapped_and_formatted(expected_predictions_mapped):
return [{k.value: v for k, v in epm.items()} for epm in expected_predictions_mapped]
@pytest.fixture
def metadata(images, info_label_map):
page_idx = 0

View File

@ -4,10 +4,15 @@ from image_prediction.formatter.formatters.camel_case import Snake2CamelCaseKeyF
from image_prediction.formatter.formatters.enum import EnumFormatter
def test_formatter(metadata, metadata_formatted):
def test_enum_formatter(metadata, metadata_formatted):
assert list(EnumFormatter()(metadata)) == metadata_formatted
@pytest.mark.parametrize("label_format", ["probability"])
def test_enum_formatter(metadata_plus_mapped_prediction, metadata_formatted_plus_mapped_prediction_formatted):
assert list(EnumFormatter()(metadata_plus_mapped_prediction)) == metadata_formatted_plus_mapped_prediction_formatted
def test_camel_case_key_formatter(snake_case_data, camel_case_data):
assert Snake2CamelCaseKeyFormatter()(snake_case_data) == camel_case_data