refactoring: introduced key mapper base class and proba mapper key enum
This commit is contained in:
parent
0cefef4e15
commit
3eaf9dc0e1
@ -1,33 +1,13 @@
|
|||||||
from typing import Iterable
|
from image_prediction.formatter.formatters.key_formatter import KeyFormatter
|
||||||
|
|
||||||
from image_prediction.formatter.formatter import Formatter
|
|
||||||
|
|
||||||
|
|
||||||
class Snake2CamelCaseKeyFormatter(Formatter):
|
class Snake2CamelCaseKeyFormatter(KeyFormatter):
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def __format_key(key):
|
def format_key(key):
|
||||||
|
|
||||||
if isinstance(key, str):
|
if isinstance(key, str):
|
||||||
head, *tail = key.split("_")
|
head, *tail = key.split("_")
|
||||||
return head + "".join(map(str.title, tail))
|
return head + "".join(map(str.title, tail))
|
||||||
else:
|
else:
|
||||||
return key
|
return key
|
||||||
|
|
||||||
def __format(self, data):
|
|
||||||
|
|
||||||
# If we wanted to do this properly, we would need handlers for all expected types and dispatch based
|
|
||||||
# on a type comparison. This is too much engineering for the limited use-case of this class though.
|
|
||||||
if isinstance(data, Iterable) and not isinstance(data, dict) and not isinstance(data, str):
|
|
||||||
f = map(self.__format, data)
|
|
||||||
return type(data)(f) if not isinstance(data, map) else f
|
|
||||||
|
|
||||||
if not isinstance(data, dict):
|
|
||||||
return data
|
|
||||||
|
|
||||||
keys_formatted = list(map(self.__format_key, data))
|
|
||||||
|
|
||||||
return dict(zip(keys_formatted, map(self.__format, data.values())))
|
|
||||||
|
|
||||||
def format(self, data):
|
|
||||||
return self.__format(data)
|
|
||||||
|
|||||||
29
image_prediction/formatter/formatters/key_formatter.py
Normal file
29
image_prediction/formatter/formatters/key_formatter.py
Normal file
@ -0,0 +1,29 @@
|
|||||||
|
import abc
|
||||||
|
from typing import Iterable
|
||||||
|
|
||||||
|
from image_prediction.formatter.formatter import Formatter
|
||||||
|
|
||||||
|
|
||||||
|
class KeyFormatter(Formatter):
|
||||||
|
|
||||||
|
@abc.abstractmethod
|
||||||
|
def format_key(self, key):
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def __format(self, data):
|
||||||
|
|
||||||
|
# If we wanted to do this properly, we would need handlers for all expected types and dispatch based
|
||||||
|
# on a type comparison. This is too much engineering for the limited use-case of this class though.
|
||||||
|
if isinstance(data, Iterable) and not isinstance(data, dict) and not isinstance(data, str):
|
||||||
|
f = map(self.__format, data)
|
||||||
|
return type(data)(f) if not isinstance(data, map) else f
|
||||||
|
|
||||||
|
if not isinstance(data, dict):
|
||||||
|
return data
|
||||||
|
|
||||||
|
keys_formatted = list(map(self.format_key, data))
|
||||||
|
|
||||||
|
return dict(zip(keys_formatted, map(self.__format, data.values())))
|
||||||
|
|
||||||
|
def format(self, data):
|
||||||
|
return self.__format(data)
|
||||||
@ -1,3 +1,4 @@
|
|||||||
|
from enum import Enum
|
||||||
from operator import itemgetter
|
from operator import itemgetter
|
||||||
from typing import Mapping, Iterable
|
from typing import Mapping, Iterable
|
||||||
|
|
||||||
@ -8,6 +9,12 @@ from image_prediction.exceptions import UnexpectedLabelFormat
|
|||||||
from image_prediction.label_mapper.mapper import LabelMapper
|
from image_prediction.label_mapper.mapper import LabelMapper
|
||||||
|
|
||||||
|
|
||||||
|
class ProbabilityMapperKeys(Enum):
|
||||||
|
LABEL = "label"
|
||||||
|
PROBABILITIES = "probabilities"
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class ProbabilityMapper(LabelMapper):
|
class ProbabilityMapper(LabelMapper):
|
||||||
def __init__(self, labels: Mapping[int, str]):
|
def __init__(self, labels: Mapping[int, str]):
|
||||||
self.__labels = labels
|
self.__labels = labels
|
||||||
@ -27,7 +34,7 @@ class ProbabilityMapper(LabelMapper):
|
|||||||
sorted(zip(self.__labels, list(map(self.__rounder, probabilities))), key=itemgetter(1), reverse=True)
|
sorted(zip(self.__labels, list(map(self.__rounder, probabilities))), key=itemgetter(1), reverse=True)
|
||||||
)
|
)
|
||||||
most_likely = [*cls2prob][0]
|
most_likely = [*cls2prob][0]
|
||||||
return {"label": most_likely, "probabilities": cls2prob}
|
return {ProbabilityMapperKeys.LABEL: most_likely, ProbabilityMapperKeys.PROBABILITIES: cls2prob}
|
||||||
|
|
||||||
def map_labels(self, probabilities: Iterable[np.ndarray]) -> Iterable[dict]:
|
def map_labels(self, probabilities: Iterable[np.ndarray]) -> Iterable[dict]:
|
||||||
return map(self.__map_array, probabilities)
|
return map(self.__map_array, probabilities)
|
||||||
|
|||||||
@ -22,7 +22,7 @@ from image_prediction.image_extractor.extractors.mock import ImageExtractorMock
|
|||||||
from image_prediction.image_extractor.extractors.parsable import ParsablePDFImageExtractor
|
from image_prediction.image_extractor.extractors.parsable import ParsablePDFImageExtractor
|
||||||
from image_prediction.info import Info
|
from image_prediction.info import Info
|
||||||
from image_prediction.label_mapper.mappers.numeric import IndexMapper
|
from image_prediction.label_mapper.mappers.numeric import IndexMapper
|
||||||
from image_prediction.label_mapper.mappers.probability import ProbabilityMapper
|
from image_prediction.label_mapper.mappers.probability import ProbabilityMapper, ProbabilityMapperKeys
|
||||||
from image_prediction.model_loader.database.connectors.mock import DatabaseConnectorMock
|
from image_prediction.model_loader.database.connectors.mock import DatabaseConnectorMock
|
||||||
from image_prediction.model_loader.loader import ModelLoader
|
from image_prediction.model_loader.loader import ModelLoader
|
||||||
from image_prediction.model_loader.loaders.mlflow import MlflowConnector
|
from image_prediction.model_loader.loaders.mlflow import MlflowConnector
|
||||||
@ -200,7 +200,7 @@ def batch_of_expected_label_to_probability_mappings(batch_of_expected_probabilit
|
|||||||
def map_probabilities(probabilities):
|
def map_probabilities(probabilities):
|
||||||
lbl2prob = dict(sorted(zip(classes, map(rounder, probabilities)), key=itemgetter(1), reverse=True))
|
lbl2prob = dict(sorted(zip(classes, map(rounder, probabilities)), key=itemgetter(1), reverse=True))
|
||||||
most_likely = [*lbl2prob][0]
|
most_likely = [*lbl2prob][0]
|
||||||
return {"label": most_likely, "probabilities": lbl2prob}
|
return {ProbabilityMapperKeys.LABEL: most_likely, ProbabilityMapperKeys.PROBABILITIES: lbl2prob}
|
||||||
|
|
||||||
rounder = rcompose(partial(np.round, decimals=4), float)
|
rounder = rcompose(partial(np.round, decimals=4), float)
|
||||||
return list(map(map_probabilities, batch_of_expected_probability_arrays))
|
return list(map(map_probabilities, batch_of_expected_probability_arrays))
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user