refactoring; added compositor for formatters
This commit is contained in:
parent
0921ef9a4f
commit
dc1cdde458
0
image_prediction/compositor/__init__.py
Normal file
0
image_prediction/compositor/__init__.py
Normal file
13
image_prediction/compositor/compositor.py
Normal file
13
image_prediction/compositor/compositor.py
Normal file
@ -0,0 +1,13 @@
|
||||
from funcy import rcompose
|
||||
|
||||
from image_prediction.formatter.formatter import Formatter
|
||||
|
||||
|
||||
class FormatterCompositor(Formatter):
|
||||
|
||||
def __init__(self, formatter: Formatter, *formatters: Formatter):
|
||||
formatters = (formatter, *formatters)
|
||||
self.pipe = rcompose(*formatters)
|
||||
|
||||
def format(self, obj):
|
||||
return self.pipe(obj)
|
||||
@ -3,5 +3,8 @@ import abc
|
||||
|
||||
class Formatter(abc.ABC):
|
||||
@abc.abstractmethod
|
||||
def format(self, info: dict):
|
||||
def format(self, obj):
|
||||
raise NotImplementedError
|
||||
|
||||
def __call__(self, obj):
|
||||
return self.format(obj)
|
||||
|
||||
16
image_prediction/formatter/formatters/camel_case.py
Normal file
16
image_prediction/formatter/formatters/camel_case.py
Normal file
@ -0,0 +1,16 @@
|
||||
from funcy import walk_keys, walk
|
||||
|
||||
from image_prediction.formatter.formatter import Formatter
|
||||
|
||||
|
||||
class Snake2CamelCaseKeyFormatter(Formatter):
|
||||
@staticmethod
|
||||
def __format(key: str):
|
||||
if isinstance(key, str):
|
||||
head, *tail = key.split("_")
|
||||
return head + "".join(map(lambda s: s.title(), tail))
|
||||
else:
|
||||
return key
|
||||
|
||||
def format(self, data):
|
||||
return walk(self.__format, data)
|
||||
@ -5,8 +5,9 @@ from image_prediction.formatter.formatter import Formatter
|
||||
|
||||
|
||||
class EnumFormatter(Formatter):
|
||||
def format(self, metadata: dict):
|
||||
@staticmethod
|
||||
def __format(metadata: dict):
|
||||
return {key.value if isinstance(key, Enum) else key: val for key, val in metadata.items()}
|
||||
|
||||
def __call__(self, metadata: List[dict]):
|
||||
return map(self.format, metadata)
|
||||
def format(self, metadata: List[dict]):
|
||||
return map(self.__format, metadata)
|
||||
7
image_prediction/formatter/formatters/identity.py
Normal file
7
image_prediction/formatter/formatters/identity.py
Normal file
@ -0,0 +1,7 @@
|
||||
from image_prediction.formatter.formatter import Formatter
|
||||
|
||||
|
||||
class IdentityFormatter(Formatter):
|
||||
|
||||
def format(self, obj):
|
||||
return obj
|
||||
@ -2,7 +2,7 @@ from operator import itemgetter
|
||||
from typing import Mapping, Iterable
|
||||
|
||||
import numpy as np
|
||||
from funcy import rcompose
|
||||
from funcy import rcompose, rpartial
|
||||
|
||||
from image_prediction.exceptions import UnexpectedLabelFormat
|
||||
from image_prediction.label_mapper.mapper import LabelMapper
|
||||
@ -13,7 +13,7 @@ class ProbabilityMapper(LabelMapper):
|
||||
self.__labels = labels
|
||||
# String conversion in the middle due to floating point precision issues.
|
||||
# See: https://stackoverflow.com/questions/56820/round-doesnt-seem-to-be-rounding-properly
|
||||
self.__rounder = rcompose(lambda d: round(d, 4), str, float)
|
||||
self.__rounder = rcompose(rpartial(round, 4), str, float)
|
||||
|
||||
def __validate_array_label_format(self, probabilities: np.ndarray) -> None:
|
||||
if not len(probabilities) == len(self.__labels):
|
||||
|
||||
15
test/unit_tests/compositor_test.py
Normal file
15
test/unit_tests/compositor_test.py
Normal file
@ -0,0 +1,15 @@
|
||||
import pytest
|
||||
|
||||
from image_prediction.formatter.formatters.enum import EnumFormatter
|
||||
from image_prediction.formatter.formatters.identity import IdentityFormatter
|
||||
from image_prediction.compositor.compositor import FormatterCompositor
|
||||
|
||||
|
||||
def test_single_formatter(metadata):
|
||||
compositor = FormatterCompositor(IdentityFormatter())
|
||||
assert metadata == compositor(metadata)
|
||||
|
||||
|
||||
def test_two_formatters(metadata, metadata_formatted):
|
||||
compositor = FormatterCompositor(IdentityFormatter(), EnumFormatter())
|
||||
assert metadata_formatted == list(compositor(metadata))
|
||||
@ -1,5 +1,22 @@
|
||||
from image_prediction.formatter.formatters.info_formatter import EnumFormatter
|
||||
import pytest
|
||||
|
||||
from image_prediction.formatter.formatters.camel_case import Snake2CamelCaseKeyFormatter
|
||||
from image_prediction.formatter.formatters.enum import EnumFormatter
|
||||
|
||||
|
||||
def test_formatter(metadata, metadata_formatted):
|
||||
assert list(EnumFormatter()(metadata)) == metadata_formatted
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def snake_case_data():
|
||||
return {"a_key": {"key": None, "key_2": ["may_not_be_changed"]}, 2: {"yet_another_key": None}}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def camel_case_data():
|
||||
return {"aKey": {"key": None, "key2": ["may_not_be_changed"]}, 2: {"yetAnotherKey": None}}
|
||||
|
||||
|
||||
# def test_camel_case_key_formatter(snake_case_data, camel_case_data):
|
||||
# assert Snake2CamelCaseKeyFormatter()(snake_case_data) == camel_case_data
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user