diff --git a/image_prediction/formatter/formatters/camel_case.py b/image_prediction/formatter/formatters/camel_case.py index a91a727..0b12781 100644 --- a/image_prediction/formatter/formatters/camel_case.py +++ b/image_prediction/formatter/formatters/camel_case.py @@ -1,33 +1,13 @@ -from typing import Iterable - -from image_prediction.formatter.formatter import Formatter +from image_prediction.formatter.formatters.key_formatter import KeyFormatter -class Snake2CamelCaseKeyFormatter(Formatter): +class Snake2CamelCaseKeyFormatter(KeyFormatter): @staticmethod - def __format_key(key): + def format_key(key): if isinstance(key, str): head, *tail = key.split("_") return head + "".join(map(str.title, tail)) else: return key - - def __format(self, data): - - # If we wanted to do this properly, we would need handlers for all expected types and dispatch based - # on a type comparison. This is too much engineering for the limited use-case of this class though. - if isinstance(data, Iterable) and not isinstance(data, dict) and not isinstance(data, str): - f = map(self.__format, data) - return type(data)(f) if not isinstance(data, map) else f - - if not isinstance(data, dict): - return data - - keys_formatted = list(map(self.__format_key, data)) - - return dict(zip(keys_formatted, map(self.__format, data.values()))) - - def format(self, data): - return self.__format(data) diff --git a/image_prediction/formatter/formatters/key_formatter.py b/image_prediction/formatter/formatters/key_formatter.py new file mode 100644 index 0000000..fbf0efd --- /dev/null +++ b/image_prediction/formatter/formatters/key_formatter.py @@ -0,0 +1,29 @@ +import abc +from typing import Iterable + +from image_prediction.formatter.formatter import Formatter + + +class KeyFormatter(Formatter): + + @abc.abstractmethod + def format_key(self, key): + raise NotImplementedError + + def __format(self, data): + + # If we wanted to do this properly, we would need handlers for all expected types and dispatch based + # on a type comparison. This is too much engineering for the limited use-case of this class though. + if isinstance(data, Iterable) and not isinstance(data, dict) and not isinstance(data, str): + f = map(self.__format, data) + return type(data)(f) if not isinstance(data, map) else f + + if not isinstance(data, dict): + return data + + keys_formatted = list(map(self.format_key, data)) + + return dict(zip(keys_formatted, map(self.__format, data.values()))) + + def format(self, data): + return self.__format(data) diff --git a/image_prediction/label_mapper/mappers/probability.py b/image_prediction/label_mapper/mappers/probability.py index 66825bb..f201599 100644 --- a/image_prediction/label_mapper/mappers/probability.py +++ b/image_prediction/label_mapper/mappers/probability.py @@ -1,3 +1,4 @@ +from enum import Enum from operator import itemgetter from typing import Mapping, Iterable @@ -8,6 +9,12 @@ from image_prediction.exceptions import UnexpectedLabelFormat from image_prediction.label_mapper.mapper import LabelMapper +class ProbabilityMapperKeys(Enum): + LABEL = "label" + PROBABILITIES = "probabilities" + + + class ProbabilityMapper(LabelMapper): def __init__(self, labels: Mapping[int, str]): self.__labels = labels @@ -27,7 +34,7 @@ class ProbabilityMapper(LabelMapper): sorted(zip(self.__labels, list(map(self.__rounder, probabilities))), key=itemgetter(1), reverse=True) ) most_likely = [*cls2prob][0] - return {"label": most_likely, "probabilities": cls2prob} + return {ProbabilityMapperKeys.LABEL: most_likely, ProbabilityMapperKeys.PROBABILITIES: cls2prob} def map_labels(self, probabilities: Iterable[np.ndarray]) -> Iterable[dict]: return map(self.__map_array, probabilities) diff --git a/test/unit_tests/conftest.py b/test/unit_tests/conftest.py index 671ff3f..9f7c5c7 100644 --- a/test/unit_tests/conftest.py +++ b/test/unit_tests/conftest.py @@ -22,7 +22,7 @@ from image_prediction.image_extractor.extractors.mock import ImageExtractorMock from image_prediction.image_extractor.extractors.parsable import ParsablePDFImageExtractor from image_prediction.info import Info from image_prediction.label_mapper.mappers.numeric import IndexMapper -from image_prediction.label_mapper.mappers.probability import ProbabilityMapper +from image_prediction.label_mapper.mappers.probability import ProbabilityMapper, ProbabilityMapperKeys from image_prediction.model_loader.database.connectors.mock import DatabaseConnectorMock from image_prediction.model_loader.loader import ModelLoader from image_prediction.model_loader.loaders.mlflow import MlflowConnector @@ -200,7 +200,7 @@ def batch_of_expected_label_to_probability_mappings(batch_of_expected_probabilit def map_probabilities(probabilities): lbl2prob = dict(sorted(zip(classes, map(rounder, probabilities)), key=itemgetter(1), reverse=True)) most_likely = [*lbl2prob][0] - return {"label": most_likely, "probabilities": lbl2prob} + return {ProbabilityMapperKeys.LABEL: most_likely, ProbabilityMapperKeys.PROBABILITIES: lbl2prob} rounder = rcompose(partial(np.round, decimals=4), float) return list(map(map_probabilities, batch_of_expected_probability_arrays))