derived enum formatter from key formatter
This commit is contained in:
parent
3eaf9dc0e1
commit
e17912caa9
@ -9,7 +9,7 @@ class Formatter(Transformer):
|
||||
raise NotImplementedError
|
||||
|
||||
def transform(self, obj):
|
||||
return self.format(obj)
|
||||
raise NotImplementedError()
|
||||
|
||||
def __call__(self, obj):
|
||||
return self.format(obj)
|
||||
|
||||
@ -1,13 +1,10 @@
|
||||
from enum import Enum
|
||||
from typing import List
|
||||
|
||||
from image_prediction.formatter.formatter import Formatter
|
||||
from image_prediction.formatter.formatters.key_formatter import KeyFormatter
|
||||
|
||||
|
||||
class EnumFormatter(Formatter):
|
||||
class EnumFormatter(KeyFormatter):
|
||||
@staticmethod
|
||||
def __format(metadata: dict):
|
||||
return {key.value if isinstance(key, Enum) else key: val for key, val in metadata.items()}
|
||||
|
||||
def format(self, metadata: List[dict]):
|
||||
return map(self.__format, metadata)
|
||||
def format_key(key):
|
||||
return key.value if isinstance(key, Enum) else key
|
||||
|
||||
@ -225,6 +225,23 @@ def map_labels(numeric_labels, classes):
|
||||
return [classes[nl] for nl in numeric_labels]
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def metadata_plus_mapped_prediction(expected_predictions_mapped, metadata):
|
||||
return [{"prediction": epm, **mdt} for epm, mdt in zip(expected_predictions_mapped, metadata)]
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def metadata_formatted_plus_mapped_prediction_formatted(expected_predictions_mapped_and_formatted, metadata_formatted):
|
||||
return [
|
||||
{"prediction": epm, **mdt} for epm, mdt in zip(expected_predictions_mapped_and_formatted, metadata_formatted)
|
||||
]
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def expected_predictions_mapped_and_formatted(expected_predictions_mapped):
|
||||
return [{k.value: v for k, v in epm.items()} for epm in expected_predictions_mapped]
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def metadata(images, info_label_map):
|
||||
page_idx = 0
|
||||
|
||||
@ -4,10 +4,15 @@ from image_prediction.formatter.formatters.camel_case import Snake2CamelCaseKeyF
|
||||
from image_prediction.formatter.formatters.enum import EnumFormatter
|
||||
|
||||
|
||||
def test_formatter(metadata, metadata_formatted):
|
||||
def test_enum_formatter(metadata, metadata_formatted):
|
||||
assert list(EnumFormatter()(metadata)) == metadata_formatted
|
||||
|
||||
|
||||
@pytest.mark.parametrize("label_format", ["probability"])
|
||||
def test_enum_formatter(metadata_plus_mapped_prediction, metadata_formatted_plus_mapped_prediction_formatted):
|
||||
assert list(EnumFormatter()(metadata_plus_mapped_prediction)) == metadata_formatted_plus_mapped_prediction_formatted
|
||||
|
||||
|
||||
def test_camel_case_key_formatter(snake_case_data, camel_case_data):
|
||||
assert Snake2CamelCaseKeyFormatter()(snake_case_data) == camel_case_data
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user