refactoring; added compositor for formatters
This commit is contained in:
parent
0921ef9a4f
commit
dc1cdde458
0
image_prediction/compositor/__init__.py
Normal file
0
image_prediction/compositor/__init__.py
Normal file
13
image_prediction/compositor/compositor.py
Normal file
13
image_prediction/compositor/compositor.py
Normal file
@ -0,0 +1,13 @@
|
|||||||
|
from funcy import rcompose
|
||||||
|
|
||||||
|
from image_prediction.formatter.formatter import Formatter
|
||||||
|
|
||||||
|
|
||||||
|
class FormatterCompositor(Formatter):
|
||||||
|
|
||||||
|
def __init__(self, formatter: Formatter, *formatters: Formatter):
|
||||||
|
formatters = (formatter, *formatters)
|
||||||
|
self.pipe = rcompose(*formatters)
|
||||||
|
|
||||||
|
def format(self, obj):
|
||||||
|
return self.pipe(obj)
|
||||||
@ -3,5 +3,8 @@ import abc
|
|||||||
|
|
||||||
class Formatter(abc.ABC):
|
class Formatter(abc.ABC):
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
def format(self, info: dict):
|
def format(self, obj):
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def __call__(self, obj):
|
||||||
|
return self.format(obj)
|
||||||
|
|||||||
16
image_prediction/formatter/formatters/camel_case.py
Normal file
16
image_prediction/formatter/formatters/camel_case.py
Normal file
@ -0,0 +1,16 @@
|
|||||||
|
from funcy import walk_keys, walk
|
||||||
|
|
||||||
|
from image_prediction.formatter.formatter import Formatter
|
||||||
|
|
||||||
|
|
||||||
|
class Snake2CamelCaseKeyFormatter(Formatter):
|
||||||
|
@staticmethod
|
||||||
|
def __format(key: str):
|
||||||
|
if isinstance(key, str):
|
||||||
|
head, *tail = key.split("_")
|
||||||
|
return head + "".join(map(lambda s: s.title(), tail))
|
||||||
|
else:
|
||||||
|
return key
|
||||||
|
|
||||||
|
def format(self, data):
|
||||||
|
return walk(self.__format, data)
|
||||||
@ -5,8 +5,9 @@ from image_prediction.formatter.formatter import Formatter
|
|||||||
|
|
||||||
|
|
||||||
class EnumFormatter(Formatter):
|
class EnumFormatter(Formatter):
|
||||||
def format(self, metadata: dict):
|
@staticmethod
|
||||||
|
def __format(metadata: dict):
|
||||||
return {key.value if isinstance(key, Enum) else key: val for key, val in metadata.items()}
|
return {key.value if isinstance(key, Enum) else key: val for key, val in metadata.items()}
|
||||||
|
|
||||||
def __call__(self, metadata: List[dict]):
|
def format(self, metadata: List[dict]):
|
||||||
return map(self.format, metadata)
|
return map(self.__format, metadata)
|
||||||
7
image_prediction/formatter/formatters/identity.py
Normal file
7
image_prediction/formatter/formatters/identity.py
Normal file
@ -0,0 +1,7 @@
|
|||||||
|
from image_prediction.formatter.formatter import Formatter
|
||||||
|
|
||||||
|
|
||||||
|
class IdentityFormatter(Formatter):
|
||||||
|
|
||||||
|
def format(self, obj):
|
||||||
|
return obj
|
||||||
@ -2,7 +2,7 @@ from operator import itemgetter
|
|||||||
from typing import Mapping, Iterable
|
from typing import Mapping, Iterable
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from funcy import rcompose
|
from funcy import rcompose, rpartial
|
||||||
|
|
||||||
from image_prediction.exceptions import UnexpectedLabelFormat
|
from image_prediction.exceptions import UnexpectedLabelFormat
|
||||||
from image_prediction.label_mapper.mapper import LabelMapper
|
from image_prediction.label_mapper.mapper import LabelMapper
|
||||||
@ -13,7 +13,7 @@ class ProbabilityMapper(LabelMapper):
|
|||||||
self.__labels = labels
|
self.__labels = labels
|
||||||
# String conversion in the middle due to floating point precision issues.
|
# String conversion in the middle due to floating point precision issues.
|
||||||
# See: https://stackoverflow.com/questions/56820/round-doesnt-seem-to-be-rounding-properly
|
# See: https://stackoverflow.com/questions/56820/round-doesnt-seem-to-be-rounding-properly
|
||||||
self.__rounder = rcompose(lambda d: round(d, 4), str, float)
|
self.__rounder = rcompose(rpartial(round, 4), str, float)
|
||||||
|
|
||||||
def __validate_array_label_format(self, probabilities: np.ndarray) -> None:
|
def __validate_array_label_format(self, probabilities: np.ndarray) -> None:
|
||||||
if not len(probabilities) == len(self.__labels):
|
if not len(probabilities) == len(self.__labels):
|
||||||
|
|||||||
15
test/unit_tests/compositor_test.py
Normal file
15
test/unit_tests/compositor_test.py
Normal file
@ -0,0 +1,15 @@
|
|||||||
|
import pytest
|
||||||
|
|
||||||
|
from image_prediction.formatter.formatters.enum import EnumFormatter
|
||||||
|
from image_prediction.formatter.formatters.identity import IdentityFormatter
|
||||||
|
from image_prediction.compositor.compositor import FormatterCompositor
|
||||||
|
|
||||||
|
|
||||||
|
def test_single_formatter(metadata):
|
||||||
|
compositor = FormatterCompositor(IdentityFormatter())
|
||||||
|
assert metadata == compositor(metadata)
|
||||||
|
|
||||||
|
|
||||||
|
def test_two_formatters(metadata, metadata_formatted):
|
||||||
|
compositor = FormatterCompositor(IdentityFormatter(), EnumFormatter())
|
||||||
|
assert metadata_formatted == list(compositor(metadata))
|
||||||
@ -1,5 +1,22 @@
|
|||||||
from image_prediction.formatter.formatters.info_formatter import EnumFormatter
|
import pytest
|
||||||
|
|
||||||
|
from image_prediction.formatter.formatters.camel_case import Snake2CamelCaseKeyFormatter
|
||||||
|
from image_prediction.formatter.formatters.enum import EnumFormatter
|
||||||
|
|
||||||
|
|
||||||
def test_formatter(metadata, metadata_formatted):
|
def test_formatter(metadata, metadata_formatted):
|
||||||
assert list(EnumFormatter()(metadata)) == metadata_formatted
|
assert list(EnumFormatter()(metadata)) == metadata_formatted
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def snake_case_data():
|
||||||
|
return {"a_key": {"key": None, "key_2": ["may_not_be_changed"]}, 2: {"yet_another_key": None}}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def camel_case_data():
|
||||||
|
return {"aKey": {"key": None, "key2": ["may_not_be_changed"]}, 2: {"yetAnotherKey": None}}
|
||||||
|
|
||||||
|
|
||||||
|
# def test_camel_case_key_formatter(snake_case_data, camel_case_data):
|
||||||
|
# assert Snake2CamelCaseKeyFormatter()(snake_case_data) == camel_case_data
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user