from operator import itemgetter
from typing import Any, Dict, List, Mapping, Tuple, Union

import numpy as np
from PIL.Image import Image

from image_prediction.estimator.adapter.adapter import EstimatorAdapter
from image_prediction.exceptions import UnexpectedLabelFormat
from image_prediction.utils import get_logger

logger = get_logger()


class Classifier:
    def __init__(self, estimator_adapter: EstimatorAdapter, classes: Mapping[int, str]):
        """Abstraction layer over different estimator backends (e.g. keras or scikit-learn).

        For each backend to be used an EstimatorAdapter must be implemented.

        Args:
            estimator_adapter: adapter for a given estimator backend
            classes: mapping from a numerical label to a human-readable label for classes
        """
        self.__estimator_adapter = estimator_adapter
        self._classes = classes

    def __validate_array_prediction_format(self, prediction: np.ndarray) -> None:
        """Check that a probability vector has exactly one entry per class.

        Raises:
            UnexpectedLabelFormat: if the number of probabilities differs from the
                number of classes (in either direction).
        """
        n_received = len(prediction)
        n_expected = len(self._classes)
        if n_received != n_expected:
            # The check rejects both too few and too many entries, so the message
            # must say which one happened.
            relation = "fewer" if n_received < n_expected else "more"
            raise UnexpectedLabelFormat(
                f"Received {relation} probabilities ({n_received}) than classes were specified ({n_expected})."
            )

    def __validate_int_prediction_format(self, prediction: int) -> None:
        """Check that an integer prediction is a valid class index.

        Raises:
            UnexpectedLabelFormat: if ``prediction`` is outside ``0 .. len(classes) - 1``.
        """
        # Exclusive upper bound: index len(classes) has no associated label.
        # NOTE(review): assumes the class keys are the contiguous range
        # 0 .. len(classes) - 1 — confirm against how `classes` is built.
        if not 0 <= prediction < len(self._classes):
            raise UnexpectedLabelFormat(
                f"Received class index '{prediction}' as prediction that has no associated class label."
            )

    def __format_array_prediction_format(self, prediction: np.ndarray) -> Dict[str, Any]:
        """Turn a per-class score vector into the most likely label plus all scores.

        Returns:
            ``{"label": <entry with the highest score>,
               "probabilities": {<entry>: <score>, ...}}``. The probabilities dict
            is ordered by descending score.
        """
        # NOTE(review): iterating a Mapping yields its keys (the numeric labels),
        # while iterating a sequence of names yields the names. Confirm which
        # container callers pass if human-readable labels are expected here.
        cls2prob = dict(sorted(zip(self._classes, prediction), key=itemgetter(1), reverse=True))
        # Kept as list indexing on purpose: next(iter(...)) would raise
        # StopIteration, which `list(map(...))` in `predict` would silently
        # treat as end-of-iteration.
        most_likely = [*cls2prob][0]
        return {"label": most_likely, "probabilities": cls2prob}

    def __format_prediction(self, prediction: Any) -> Any:
        """Convert one raw estimator output into its public representation.

        - integer (Python ``int`` or NumPy integer scalar): validated and mapped
          through ``classes``;
        - ``np.ndarray``: validated and formatted as label plus probabilities;
        - anything else: returned unchanged.
        """
        # np.integer covers the scalars produced by iterating an integer ndarray
        # (e.g. np.int64), which are not instances of the builtin `int`.
        if isinstance(prediction, (int, np.integer)):
            self.__validate_int_prediction_format(prediction)
            return self._classes[prediction]
        elif isinstance(prediction, np.ndarray):
            self.__validate_array_prediction_format(prediction)
            return self.__format_array_prediction_format(prediction)
        else:
            return prediction

    def predict(self, batch: Union[np.ndarray, Tuple[Image, ...]]) -> List[Any]:
        """Run the estimator on ``batch`` and format each per-item prediction.

        Args:
            batch: an ndarray batch (first axis indexes items) or a tuple of images.

        Returns:
            One formatted prediction per element returned by the adapter. An empty
            batch yields ``[]`` without calling the adapter.
        """
        # len() works for both an ndarray (size of the first axis) and a tuple,
        # so an empty tuple of images is short-circuited too.
        if len(batch) == 0:
            return []
        return list(map(self.__format_prediction, self.__estimator_adapter.predict(batch)))

    def __call__(self, batch: Union[np.ndarray, Tuple[Image, ...]]) -> List[Any]:
        """Alias for :meth:`predict`."""
        return self.predict(batch)