refactoring: introduced key formatter base class and probability mapper key enum

This commit is contained in:
Matthias Bisping 2022-03-31 16:55:58 +02:00
parent 0cefef4e15
commit 3eaf9dc0e1
4 changed files with 42 additions and 26 deletions

View File

@ -1,33 +1,13 @@
from typing import Iterable from image_prediction.formatter.formatters.key_formatter import KeyFormatter
from image_prediction.formatter.formatter import Formatter
class Snake2CamelCaseKeyFormatter(Formatter): class Snake2CamelCaseKeyFormatter(KeyFormatter):
@staticmethod @staticmethod
def __format_key(key): def format_key(key):
if isinstance(key, str): if isinstance(key, str):
head, *tail = key.split("_") head, *tail = key.split("_")
return head + "".join(map(str.title, tail)) return head + "".join(map(str.title, tail))
else: else:
return key return key
def __format(self, data):
# If we wanted to do this properly, we would need handlers for all expected types and dispatch based
# on a type comparison. This is too much engineering for the limited use-case of this class though.
if isinstance(data, Iterable) and not isinstance(data, dict) and not isinstance(data, str):
f = map(self.__format, data)
return type(data)(f) if not isinstance(data, map) else f
if not isinstance(data, dict):
return data
keys_formatted = list(map(self.__format_key, data))
return dict(zip(keys_formatted, map(self.__format, data.values())))
def format(self, data):
return self.__format(data)

View File

@ -0,0 +1,29 @@
import abc
from typing import Iterable
from image_prediction.formatter.formatter import Formatter
class KeyFormatter(Formatter):
    """Formatter that rewrites the keys of every dict nested inside a data structure.

    Subclasses provide the per-key transformation through `format_key`; this base
    class walks the data and applies that transformation to each dict key it finds.
    """

    @abc.abstractmethod
    def format_key(self, key):
        """Return the formatted version of a single dict key.

        NOTE(review): `abstractmethod` is only enforced at instantiation time if
        `Formatter` uses `abc.ABCMeta` as its metaclass -- confirm.
        """
        raise NotImplementedError

    def __format(self, data):
        """Recursively apply `format_key` to the keys of all dicts contained in `data`.

        Dicts are rebuilt as plain dicts with formatted keys and recursively formatted
        values. Strings and non-iterable values are returned unchanged. Iterators
        (map objects, generators, ...) yield a lazy `map` of formatted items. Any
        other iterable is rebuilt by passing the formatted items to `type(data)`.
        """
        # If we wanted to do this properly, we would need handlers for all expected types and dispatch based
        # on a type comparison. This is too much engineering for the limited use-case of this class though.
        if isinstance(data, dict):
            return {self.format_key(key): self.__format(value) for key, value in data.items()}

        # Strings are iterable but are treated as leaves, like every non-iterable value.
        if isinstance(data, str) or not isinstance(data, Iterable):
            return data

        formatted = map(self.__format, data)

        # An iterator returns itself from iter(). Iterator types (map, generator, zip, ...) cannot be
        # re-instantiated from an iterable of items, so the result stays a lazy map instead.
        if iter(data) is data:
            return formatted

        return type(data)(formatted)

    def format(self, data):
        """Return `data` with all nested dict keys passed through `format_key`."""
        return self.__format(data)

View File

@ -1,3 +1,4 @@
from enum import Enum
from operator import itemgetter from operator import itemgetter
from typing import Mapping, Iterable from typing import Mapping, Iterable
@ -8,6 +9,12 @@ from image_prediction.exceptions import UnexpectedLabelFormat
from image_prediction.label_mapper.mapper import LabelMapper from image_prediction.label_mapper.mapper import LabelMapper
class ProbabilityMapperKeys(Enum):
    """Keys used in the result dictionaries built by the probability mapper."""

    # Key under which the most likely label is stored.
    LABEL = "label"
    # Key under which the label-to-probability mapping is stored.
    PROBABILITIES = "probabilities"
class ProbabilityMapper(LabelMapper): class ProbabilityMapper(LabelMapper):
def __init__(self, labels: Mapping[int, str]): def __init__(self, labels: Mapping[int, str]):
self.__labels = labels self.__labels = labels
@ -27,7 +34,7 @@ class ProbabilityMapper(LabelMapper):
sorted(zip(self.__labels, list(map(self.__rounder, probabilities))), key=itemgetter(1), reverse=True) sorted(zip(self.__labels, list(map(self.__rounder, probabilities))), key=itemgetter(1), reverse=True)
) )
most_likely = [*cls2prob][0] most_likely = [*cls2prob][0]
return {"label": most_likely, "probabilities": cls2prob} return {ProbabilityMapperKeys.LABEL: most_likely, ProbabilityMapperKeys.PROBABILITIES: cls2prob}
def map_labels(self, probabilities: Iterable[np.ndarray]) -> Iterable[dict]: def map_labels(self, probabilities: Iterable[np.ndarray]) -> Iterable[dict]:
return map(self.__map_array, probabilities) return map(self.__map_array, probabilities)

View File

@ -22,7 +22,7 @@ from image_prediction.image_extractor.extractors.mock import ImageExtractorMock
from image_prediction.image_extractor.extractors.parsable import ParsablePDFImageExtractor from image_prediction.image_extractor.extractors.parsable import ParsablePDFImageExtractor
from image_prediction.info import Info from image_prediction.info import Info
from image_prediction.label_mapper.mappers.numeric import IndexMapper from image_prediction.label_mapper.mappers.numeric import IndexMapper
from image_prediction.label_mapper.mappers.probability import ProbabilityMapper from image_prediction.label_mapper.mappers.probability import ProbabilityMapper, ProbabilityMapperKeys
from image_prediction.model_loader.database.connectors.mock import DatabaseConnectorMock from image_prediction.model_loader.database.connectors.mock import DatabaseConnectorMock
from image_prediction.model_loader.loader import ModelLoader from image_prediction.model_loader.loader import ModelLoader
from image_prediction.model_loader.loaders.mlflow import MlflowConnector from image_prediction.model_loader.loaders.mlflow import MlflowConnector
@ -200,7 +200,7 @@ def batch_of_expected_label_to_probability_mappings(batch_of_expected_probabilit
def map_probabilities(probabilities): def map_probabilities(probabilities):
lbl2prob = dict(sorted(zip(classes, map(rounder, probabilities)), key=itemgetter(1), reverse=True)) lbl2prob = dict(sorted(zip(classes, map(rounder, probabilities)), key=itemgetter(1), reverse=True))
most_likely = [*lbl2prob][0] most_likely = [*lbl2prob][0]
return {"label": most_likely, "probabilities": lbl2prob} return {ProbabilityMapperKeys.LABEL: most_likely, ProbabilityMapperKeys.PROBABILITIES: lbl2prob}
rounder = rcompose(partial(np.round, decimals=4), float) rounder = rcompose(partial(np.round, decimals=4), float)
return list(map(map_probabilities, batch_of_expected_probability_arrays)) return list(map(map_probabilities, batch_of_expected_probability_arrays))