refactoring; added compositor for formatters

This commit is contained in:
Matthias Bisping 2022-03-31 12:52:15 +02:00
parent 0921ef9a4f
commit dc1cdde458
9 changed files with 79 additions and 7 deletions

View File

View File

@ -0,0 +1,13 @@
from funcy import rcompose
from image_prediction.formatter.formatter import Formatter
class FormatterCompositor(Formatter):
def __init__(self, formatter: Formatter, *formatters: Formatter):
formatters = (formatter, *formatters)
self.pipe = rcompose(*formatters)
def format(self, obj):
return self.pipe(obj)

View File

@ -3,5 +3,8 @@ import abc
class Formatter(abc.ABC): class Formatter(abc.ABC):
@abc.abstractmethod @abc.abstractmethod
def format(self, info: dict): def format(self, obj):
raise NotImplementedError raise NotImplementedError
def __call__(self, obj):
return self.format(obj)

View File

@ -0,0 +1,16 @@
from funcy import walk_keys, walk
from image_prediction.formatter.formatter import Formatter
class Snake2CamelCaseKeyFormatter(Formatter):
@staticmethod
def __format(key: str):
if isinstance(key, str):
head, *tail = key.split("_")
return head + "".join(map(lambda s: s.title(), tail))
else:
return key
def format(self, data):
return walk(self.__format, data)

View File

@ -5,8 +5,9 @@ from image_prediction.formatter.formatter import Formatter
class EnumFormatter(Formatter): class EnumFormatter(Formatter):
def format(self, metadata: dict): @staticmethod
def __format(metadata: dict):
return {key.value if isinstance(key, Enum) else key: val for key, val in metadata.items()} return {key.value if isinstance(key, Enum) else key: val for key, val in metadata.items()}
def __call__(self, metadata: List[dict]): def format(self, metadata: List[dict]):
return map(self.format, metadata) return map(self.__format, metadata)

View File

@ -0,0 +1,7 @@
from image_prediction.formatter.formatter import Formatter
class IdentityFormatter(Formatter):
def format(self, obj):
return obj

View File

@ -2,7 +2,7 @@ from operator import itemgetter
from typing import Mapping, Iterable from typing import Mapping, Iterable
import numpy as np import numpy as np
from funcy import rcompose from funcy import rcompose, rpartial
from image_prediction.exceptions import UnexpectedLabelFormat from image_prediction.exceptions import UnexpectedLabelFormat
from image_prediction.label_mapper.mapper import LabelMapper from image_prediction.label_mapper.mapper import LabelMapper
@ -13,7 +13,7 @@ class ProbabilityMapper(LabelMapper):
self.__labels = labels self.__labels = labels
# String conversion in the middle due to floating point precision issues. # String conversion in the middle due to floating point precision issues.
# See: https://stackoverflow.com/questions/56820/round-doesnt-seem-to-be-rounding-properly # See: https://stackoverflow.com/questions/56820/round-doesnt-seem-to-be-rounding-properly
self.__rounder = rcompose(lambda d: round(d, 4), str, float) self.__rounder = rcompose(rpartial(round, 4), str, float)
def __validate_array_label_format(self, probabilities: np.ndarray) -> None: def __validate_array_label_format(self, probabilities: np.ndarray) -> None:
if not len(probabilities) == len(self.__labels): if not len(probabilities) == len(self.__labels):

View File

@ -0,0 +1,15 @@
import pytest
from image_prediction.formatter.formatters.enum import EnumFormatter
from image_prediction.formatter.formatters.identity import IdentityFormatter
from image_prediction.compositor.compositor import FormatterCompositor
def test_single_formatter(metadata):
compositor = FormatterCompositor(IdentityFormatter())
assert metadata == compositor(metadata)
def test_two_formatters(metadata, metadata_formatted):
compositor = FormatterCompositor(IdentityFormatter(), EnumFormatter())
assert metadata_formatted == list(compositor(metadata))

View File

@ -1,5 +1,22 @@
from image_prediction.formatter.formatters.info_formatter import EnumFormatter import pytest
from image_prediction.formatter.formatters.camel_case import Snake2CamelCaseKeyFormatter
from image_prediction.formatter.formatters.enum import EnumFormatter
def test_formatter(metadata, metadata_formatted): def test_formatter(metadata, metadata_formatted):
assert list(EnumFormatter()(metadata)) == metadata_formatted assert list(EnumFormatter()(metadata)) == metadata_formatted
@pytest.fixture
def snake_case_data():
return {"a_key": {"key": None, "key_2": ["may_not_be_changed"]}, 2: {"yet_another_key": None}}
@pytest.fixture
def camel_case_data():
return {"aKey": {"key": None, "key2": ["may_not_be_changed"]}, 2: {"yetAnotherKey": None}}
# def test_camel_case_key_formatter(snake_case_data, camel_case_data):
# assert Snake2CamelCaseKeyFormatter()(snake_case_data) == camel_case_data