refactoring: introduced key mapper base class and proba mapper key enum
This commit is contained in:
parent
0cefef4e15
commit
3eaf9dc0e1
@ -1,33 +1,13 @@
|
||||
from typing import Iterable
|
||||
|
||||
from image_prediction.formatter.formatter import Formatter
|
||||
from image_prediction.formatter.formatters.key_formatter import KeyFormatter
|
||||
|
||||
|
||||
class Snake2CamelCaseKeyFormatter(Formatter):
|
||||
class Snake2CamelCaseKeyFormatter(KeyFormatter):
|
||||
|
||||
@staticmethod
|
||||
def __format_key(key):
|
||||
def format_key(key):
|
||||
|
||||
if isinstance(key, str):
|
||||
head, *tail = key.split("_")
|
||||
return head + "".join(map(str.title, tail))
|
||||
else:
|
||||
return key
|
||||
|
||||
def __format(self, data):
|
||||
|
||||
# If we wanted to do this properly, we would need handlers for all expected types and dispatch based
|
||||
# on a type comparison. This is too much engineering for the limited use-case of this class though.
|
||||
if isinstance(data, Iterable) and not isinstance(data, dict) and not isinstance(data, str):
|
||||
f = map(self.__format, data)
|
||||
return type(data)(f) if not isinstance(data, map) else f
|
||||
|
||||
if not isinstance(data, dict):
|
||||
return data
|
||||
|
||||
keys_formatted = list(map(self.__format_key, data))
|
||||
|
||||
return dict(zip(keys_formatted, map(self.__format, data.values())))
|
||||
|
||||
def format(self, data):
|
||||
return self.__format(data)
|
||||
|
||||
29
image_prediction/formatter/formatters/key_formatter.py
Normal file
29
image_prediction/formatter/formatters/key_formatter.py
Normal file
@ -0,0 +1,29 @@
|
||||
import abc
|
||||
from typing import Iterable
|
||||
|
||||
from image_prediction.formatter.formatter import Formatter
|
||||
|
||||
|
||||
class KeyFormatter(Formatter):
|
||||
|
||||
@abc.abstractmethod
|
||||
def format_key(self, key):
|
||||
raise NotImplementedError
|
||||
|
||||
def __format(self, data):
|
||||
|
||||
# If we wanted to do this properly, we would need handlers for all expected types and dispatch based
|
||||
# on a type comparison. This is too much engineering for the limited use-case of this class though.
|
||||
if isinstance(data, Iterable) and not isinstance(data, dict) and not isinstance(data, str):
|
||||
f = map(self.__format, data)
|
||||
return type(data)(f) if not isinstance(data, map) else f
|
||||
|
||||
if not isinstance(data, dict):
|
||||
return data
|
||||
|
||||
keys_formatted = list(map(self.format_key, data))
|
||||
|
||||
return dict(zip(keys_formatted, map(self.__format, data.values())))
|
||||
|
||||
def format(self, data):
|
||||
return self.__format(data)
|
||||
@ -1,3 +1,4 @@
|
||||
from enum import Enum
|
||||
from operator import itemgetter
|
||||
from typing import Mapping, Iterable
|
||||
|
||||
@ -8,6 +9,12 @@ from image_prediction.exceptions import UnexpectedLabelFormat
|
||||
from image_prediction.label_mapper.mapper import LabelMapper
|
||||
|
||||
|
||||
class ProbabilityMapperKeys(Enum):
|
||||
LABEL = "label"
|
||||
PROBABILITIES = "probabilities"
|
||||
|
||||
|
||||
|
||||
class ProbabilityMapper(LabelMapper):
|
||||
def __init__(self, labels: Mapping[int, str]):
|
||||
self.__labels = labels
|
||||
@ -27,7 +34,7 @@ class ProbabilityMapper(LabelMapper):
|
||||
sorted(zip(self.__labels, list(map(self.__rounder, probabilities))), key=itemgetter(1), reverse=True)
|
||||
)
|
||||
most_likely = [*cls2prob][0]
|
||||
return {"label": most_likely, "probabilities": cls2prob}
|
||||
return {ProbabilityMapperKeys.LABEL: most_likely, ProbabilityMapperKeys.PROBABILITIES: cls2prob}
|
||||
|
||||
def map_labels(self, probabilities: Iterable[np.ndarray]) -> Iterable[dict]:
|
||||
return map(self.__map_array, probabilities)
|
||||
|
||||
@ -22,7 +22,7 @@ from image_prediction.image_extractor.extractors.mock import ImageExtractorMock
|
||||
from image_prediction.image_extractor.extractors.parsable import ParsablePDFImageExtractor
|
||||
from image_prediction.info import Info
|
||||
from image_prediction.label_mapper.mappers.numeric import IndexMapper
|
||||
from image_prediction.label_mapper.mappers.probability import ProbabilityMapper
|
||||
from image_prediction.label_mapper.mappers.probability import ProbabilityMapper, ProbabilityMapperKeys
|
||||
from image_prediction.model_loader.database.connectors.mock import DatabaseConnectorMock
|
||||
from image_prediction.model_loader.loader import ModelLoader
|
||||
from image_prediction.model_loader.loaders.mlflow import MlflowConnector
|
||||
@ -200,7 +200,7 @@ def batch_of_expected_label_to_probability_mappings(batch_of_expected_probabilit
|
||||
def map_probabilities(probabilities):
|
||||
lbl2prob = dict(sorted(zip(classes, map(rounder, probabilities)), key=itemgetter(1), reverse=True))
|
||||
most_likely = [*lbl2prob][0]
|
||||
return {"label": most_likely, "probabilities": lbl2prob}
|
||||
return {ProbabilityMapperKeys.LABEL: most_likely, ProbabilityMapperKeys.PROBABILITIES: lbl2prob}
|
||||
|
||||
rounder = rcompose(partial(np.round, decimals=4), float)
|
||||
return list(map(map_probabilities, batch_of_expected_probability_arrays))
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user