diff --git a/image_prediction/formatter/formatter.py b/image_prediction/formatter/formatter.py index 6c5b97f..3f3a1f8 100644 --- a/image_prediction/formatter/formatter.py +++ b/image_prediction/formatter/formatter.py @@ -9,7 +9,7 @@ class Formatter(Transformer): raise NotImplementedError def transform(self, obj): - return self.format(obj) + raise NotImplementedError() def __call__(self, obj): return self.format(obj) diff --git a/image_prediction/formatter/formatters/enum.py b/image_prediction/formatter/formatters/enum.py index f719ead..a412b31 100644 --- a/image_prediction/formatter/formatters/enum.py +++ b/image_prediction/formatter/formatters/enum.py @@ -1,13 +1,10 @@ from enum import Enum from typing import List -from image_prediction.formatter.formatter import Formatter +from image_prediction.formatter.formatters.key_formatter import KeyFormatter -class EnumFormatter(Formatter): +class EnumFormatter(KeyFormatter): @staticmethod - def __format(metadata: dict): - return {key.value if isinstance(key, Enum) else key: val for key, val in metadata.items()} - - def format(self, metadata: List[dict]): - return map(self.__format, metadata) + def format_key(key): + return key.value if isinstance(key, Enum) else key diff --git a/test/unit_tests/conftest.py b/test/unit_tests/conftest.py index 9f7c5c7..dc31663 100644 --- a/test/unit_tests/conftest.py +++ b/test/unit_tests/conftest.py @@ -225,6 +225,23 @@ def map_labels(numeric_labels, classes): return [classes[nl] for nl in numeric_labels] +@pytest.fixture +def metadata_plus_mapped_prediction(expected_predictions_mapped, metadata): + return [{"prediction": epm, **mdt} for epm, mdt in zip(expected_predictions_mapped, metadata)] + + +@pytest.fixture +def metadata_formatted_plus_mapped_prediction_formatted(expected_predictions_mapped_and_formatted, metadata_formatted): + return [ + {"prediction": epm, **mdt} for epm, mdt in zip(expected_predictions_mapped_and_formatted, metadata_formatted) + ] + + +@pytest.fixture +def expected_predictions_mapped_and_formatted(expected_predictions_mapped): + return [{k.value: v for k, v in epm.items()} for epm in expected_predictions_mapped] + + @pytest.fixture def metadata(images, info_label_map): page_idx = 0 diff --git a/test/unit_tests/formatter_test.py b/test/unit_tests/formatter_test.py index c222147..b012b9a 100644 --- a/test/unit_tests/formatter_test.py +++ b/test/unit_tests/formatter_test.py @@ -4,10 +4,15 @@ from image_prediction.formatter.formatters.camel_case import Snake2CamelCaseKeyF from image_prediction.formatter.formatters.enum import EnumFormatter -def test_formatter(metadata, metadata_formatted): +def test_enum_formatter(metadata, metadata_formatted): assert list(EnumFormatter()(metadata)) == metadata_formatted +@pytest.mark.parametrize("label_format", ["probability"]) +def test_enum_formatter(metadata_plus_mapped_prediction, metadata_formatted_plus_mapped_prediction_formatted): + assert list(EnumFormatter()(metadata_plus_mapped_prediction)) == metadata_formatted_plus_mapped_prediction_formatted + + def test_camel_case_key_formatter(snake_case_data, camel_case_data): assert Snake2CamelCaseKeyFormatter()(snake_case_data) == camel_case_data