refactoring; added compositor for formatters

This commit is contained in:
Matthias Bisping 2022-03-31 12:52:15 +02:00
parent 0921ef9a4f
commit dc1cdde458
9 changed files with 79 additions and 7 deletions

View File

View File

@ -0,0 +1,13 @@
from funcy import rcompose
from image_prediction.formatter.formatter import Formatter
class FormatterCompositor(Formatter):
def __init__(self, formatter: Formatter, *formatters: Formatter):
formatters = (formatter, *formatters)
self.pipe = rcompose(*formatters)
def format(self, obj):
return self.pipe(obj)

View File

@ -3,5 +3,8 @@ import abc
class Formatter(abc.ABC):
@abc.abstractmethod
def format(self, info: dict):
def format(self, obj):
raise NotImplementedError
def __call__(self, obj):
return self.format(obj)

View File

@ -0,0 +1,16 @@
from funcy import walk_keys, walk
from image_prediction.formatter.formatter import Formatter
class Snake2CamelCaseKeyFormatter(Formatter):
@staticmethod
def __format(key: str):
if isinstance(key, str):
head, *tail = key.split("_")
return head + "".join(map(lambda s: s.title(), tail))
else:
return key
def format(self, data):
return walk(self.__format, data)

View File

@ -5,8 +5,9 @@ from image_prediction.formatter.formatter import Formatter
class EnumFormatter(Formatter):
def format(self, metadata: dict):
@staticmethod
def __format(metadata: dict):
return {key.value if isinstance(key, Enum) else key: val for key, val in metadata.items()}
def __call__(self, metadata: List[dict]):
return map(self.format, metadata)
def format(self, metadata: List[dict]):
return map(self.__format, metadata)

View File

@ -0,0 +1,7 @@
from image_prediction.formatter.formatter import Formatter
class IdentityFormatter(Formatter):
def format(self, obj):
return obj

View File

@ -2,7 +2,7 @@ from operator import itemgetter
from typing import Mapping, Iterable
import numpy as np
from funcy import rcompose
from funcy import rcompose, rpartial
from image_prediction.exceptions import UnexpectedLabelFormat
from image_prediction.label_mapper.mapper import LabelMapper
@ -13,7 +13,7 @@ class ProbabilityMapper(LabelMapper):
self.__labels = labels
# String conversion in the middle due to floating point precision issues.
# See: https://stackoverflow.com/questions/56820/round-doesnt-seem-to-be-rounding-properly
self.__rounder = rcompose(lambda d: round(d, 4), str, float)
self.__rounder = rcompose(rpartial(round, 4), str, float)
def __validate_array_label_format(self, probabilities: np.ndarray) -> None:
if not len(probabilities) == len(self.__labels):

View File

@ -0,0 +1,15 @@
import pytest
from image_prediction.formatter.formatters.enum import EnumFormatter
from image_prediction.formatter.formatters.identity import IdentityFormatter
from image_prediction.compositor.compositor import FormatterCompositor
def test_single_formatter(metadata):
compositor = FormatterCompositor(IdentityFormatter())
assert metadata == compositor(metadata)
def test_two_formatters(metadata, metadata_formatted):
compositor = FormatterCompositor(IdentityFormatter(), EnumFormatter())
assert metadata_formatted == list(compositor(metadata))

View File

@ -1,5 +1,22 @@
from image_prediction.formatter.formatters.info_formatter import EnumFormatter
import pytest
from image_prediction.formatter.formatters.camel_case import Snake2CamelCaseKeyFormatter
from image_prediction.formatter.formatters.enum import EnumFormatter
def test_formatter(metadata, metadata_formatted):
assert list(EnumFormatter()(metadata)) == metadata_formatted
@pytest.fixture
def snake_case_data():
return {"a_key": {"key": None, "key_2": ["may_not_be_changed"]}, 2: {"yet_another_key": None}}
@pytest.fixture
def camel_case_data():
return {"aKey": {"key": None, "key2": ["may_not_be_changed"]}, 2: {"yetAnotherKey": None}}
# def test_camel_case_key_formatter(snake_case_data, camel_case_data):
# assert Snake2CamelCaseKeyFormatter()(snake_case_data) == camel_case_data