diff --git a/image_prediction/compositor/__init__.py b/image_prediction/compositor/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/image_prediction/compositor/compositor.py b/image_prediction/compositor/compositor.py new file mode 100644 index 0000000..cb39ea3 --- /dev/null +++ b/image_prediction/compositor/compositor.py @@ -0,0 +1,13 @@ +from funcy import rcompose + +from image_prediction.formatter.formatter import Formatter + + +class FormatterCompositor(Formatter): + + def __init__(self, formatter: Formatter, *formatters: Formatter): + formatters = (formatter, *formatters) + self.pipe = rcompose(*formatters) + + def format(self, obj): + return self.pipe(obj) diff --git a/image_prediction/formatter/formatter.py b/image_prediction/formatter/formatter.py index ab1688b..2129813 100644 --- a/image_prediction/formatter/formatter.py +++ b/image_prediction/formatter/formatter.py @@ -3,5 +3,8 @@ import abc class Formatter(abc.ABC): @abc.abstractmethod - def format(self, info: dict): + def format(self, obj): raise NotImplementedError + + def __call__(self, obj): + return self.format(obj) diff --git a/image_prediction/formatter/formatters/camel_case.py b/image_prediction/formatter/formatters/camel_case.py new file mode 100644 index 0000000..668659d --- /dev/null +++ b/image_prediction/formatter/formatters/camel_case.py @@ -0,0 +1,16 @@ +from funcy import walk_keys, walk + +from image_prediction.formatter.formatter import Formatter + + +class Snake2CamelCaseKeyFormatter(Formatter): + @staticmethod + def __format(key: str): + if isinstance(key, str): + head, *tail = key.split("_") + return head + "".join(map(lambda s: s.title(), tail)) + else: + return key + + def format(self, data): + return walk(self.__format, data) diff --git a/image_prediction/formatter/formatters/info_formatter.py b/image_prediction/formatter/formatters/enum.py similarity index 63% rename from image_prediction/formatter/formatters/info_formatter.py rename to image_prediction/formatter/formatters/enum.py index dfacd58..f719ead 100644 --- a/image_prediction/formatter/formatters/info_formatter.py +++ b/image_prediction/formatter/formatters/enum.py @@ -5,8 +5,9 @@ from image_prediction.formatter.formatter import Formatter class EnumFormatter(Formatter): - def format(self, metadata: dict): + @staticmethod + def __format(metadata: dict): return {key.value if isinstance(key, Enum) else key: val for key, val in metadata.items()} - def __call__(self, metadata: List[dict]): - return map(self.format, metadata) + def format(self, metadata: List[dict]): + return map(self.__format, metadata) diff --git a/image_prediction/formatter/formatters/identity.py b/image_prediction/formatter/formatters/identity.py new file mode 100644 index 0000000..8443f02 --- /dev/null +++ b/image_prediction/formatter/formatters/identity.py @@ -0,0 +1,7 @@ +from image_prediction.formatter.formatter import Formatter + + +class IdentityFormatter(Formatter): + + def format(self, obj): + return obj diff --git a/image_prediction/label_mapper/mappers/probability.py b/image_prediction/label_mapper/mappers/probability.py index 9e7347c..66825bb 100644 --- a/image_prediction/label_mapper/mappers/probability.py +++ b/image_prediction/label_mapper/mappers/probability.py @@ -2,7 +2,7 @@ from operator import itemgetter from typing import Mapping, Iterable import numpy as np -from funcy import rcompose +from funcy import rcompose, rpartial from image_prediction.exceptions import UnexpectedLabelFormat from image_prediction.label_mapper.mapper import LabelMapper @@ -13,7 +13,7 @@ class ProbabilityMapper(LabelMapper): self.__labels = labels # String conversion in the middle due to floating point precision issues. # See: https://stackoverflow.com/questions/56820/round-doesnt-seem-to-be-rounding-properly - self.__rounder = rcompose(lambda d: round(d, 4), str, float) + self.__rounder = rcompose(rpartial(round, 4), str, float) def __validate_array_label_format(self, probabilities: np.ndarray) -> None: if not len(probabilities) == len(self.__labels): diff --git a/test/unit_tests/compositor_test.py b/test/unit_tests/compositor_test.py new file mode 100644 index 0000000..88316c9 --- /dev/null +++ b/test/unit_tests/compositor_test.py @@ -0,0 +1,15 @@ +import pytest + +from image_prediction.formatter.formatters.enum import EnumFormatter +from image_prediction.formatter.formatters.identity import IdentityFormatter +from image_prediction.compositor.compositor import FormatterCompositor + + +def test_single_formatter(metadata): + compositor = FormatterCompositor(IdentityFormatter()) + assert metadata == compositor(metadata) + + +def test_two_formatters(metadata, metadata_formatted): + compositor = FormatterCompositor(IdentityFormatter(), EnumFormatter()) + assert metadata_formatted == list(compositor(metadata)) diff --git a/test/unit_tests/formatter_test.py b/test/unit_tests/formatter_test.py index a8764eb..2e1d0e4 100644 --- a/test/unit_tests/formatter_test.py +++ b/test/unit_tests/formatter_test.py @@ -1,5 +1,22 @@ -from image_prediction.formatter.formatters.info_formatter import EnumFormatter +import pytest + +from image_prediction.formatter.formatters.camel_case import Snake2CamelCaseKeyFormatter +from image_prediction.formatter.formatters.enum import EnumFormatter def test_formatter(metadata, metadata_formatted): assert list(EnumFormatter()(metadata)) == metadata_formatted + + +@pytest.fixture +def snake_case_data(): + return {"a_key": {"key": None, "key_2": ["may_not_be_changed"]}, 2: {"yet_another_key": None}} + + +@pytest.fixture +def camel_case_data(): + return {"aKey": {"key": None, "key2": ["may_not_be_changed"]}, 2: {"yetAnotherKey": None}} + + +# def test_camel_case_key_formatter(snake_case_data, camel_case_data): +# assert Snake2CamelCaseKeyFormatter()(snake_case_data) == camel_case_data