from typing import Dict, List, Mapping, Tuple, Union

import numpy as np
from PIL.Image import Image

from image_prediction.estimator.adapter.adapter import EstimatorAdapter
from image_prediction.exceptions import UnexpectedPredictionFormat
from image_prediction.utils import get_logger

logger = get_logger()


class Classifier:
    """Translate raw estimator predictions into human-readable class labels."""

    def __init__(self, estimator_adapter: EstimatorAdapter, classes: Mapping[int, str]):
        """Abstraction layer over different estimator backends (e.g. keras or scikit-learn). For each backend to be used
        an EstimatorAdapter must be implemented.

        Args:
            estimator_adapter: adapter for a given estimator backend; expected to be a classifier that returns
                mappings from numeric labels to probabilities as predictions or numeric labels
            classes: mapping from a numerical label to a human-readable label for classes
        """
        self.__estimator_adapter = estimator_adapter
        self._classes = classes

    def __class_names(self) -> List[str]:
        """Return the human-readable class labels in the iteration order of ``classes``."""
        if isinstance(self._classes, Mapping):
            return list(self._classes.values())
        # A plain sequence of names is tolerated as well: its items already are the labels.
        return list(self._classes)

    def __validate_dict_prediction_format(self, prediction):
        """Ensure that every label of a dict prediction can be resolved to a class.

        Raises:
            UnexpectedPredictionFormat: if at least one label is unknown to ``classes``.
        """
        try:
            # Perform exactly the lookup the formatting step relies on, so validation and formatting cannot disagree.
            for label in prediction:
                self._classes[label]
        except (KeyError, IndexError, TypeError) as err:
            raise UnexpectedPredictionFormat(
                f"Received prediction in an unexpected format: {prediction}"
            ) from err

    def __validate_array_prediction_format(self, prediction):
        """Ensure that an array prediction holds exactly one entry per class.

        Raises:
            UnexpectedPredictionFormat: if the number of entries differs from the number of classes.
        """
        if len(prediction) != len(self._classes):
            raise UnexpectedPredictionFormat(
                f"Received a different number of probabilities ({len(prediction)}) "
                f"than classes were specified ({len(self._classes)})."
            )

    def __format_prediction(self, prediction):
        """Convert a single raw prediction into its human-readable form.

        * integer label (Python or NumPy integer) -> class name
        * dict of label -> probability -> dict of class name -> probability
        * array of probabilities -> dict of class name -> probability, paired by position
        * anything else is passed through unchanged
        """
        # NumPy integer scalars (e.g. the result of an argmax) are not instances of int, hence the explicit check.
        if isinstance(prediction, (int, np.integer)):
            return self._classes[prediction]

        if isinstance(prediction, dict):
            self.__validate_dict_prediction_format(prediction)
            return {self._classes[label]: probability for label, probability in prediction.items()}

        if isinstance(prediction, np.ndarray):
            self.__validate_array_prediction_format(prediction)
            return dict(zip(self.__class_names(), prediction))

        return prediction

    def predict(self, batch: Union[np.ndarray, Tuple[Image, ...]]) -> List[Union[str, Dict[str, float]]]:
        """Run the estimator on a batch and format each prediction.

        Args:
            batch: array of inputs (first axis is the batch axis) or tuple of images

        Returns:
            One formatted prediction per batch item; an empty list for an empty batch.
        """
        if isinstance(batch, tuple):
            batch_is_empty = len(batch) == 0
        else:
            batch_is_empty = batch.shape[0] == 0

        # Never hand an empty batch to the estimator backend.
        if batch_is_empty:
            return []

        raw_predictions = self.__estimator_adapter.predict(batch)
        return [self.__format_prediction(raw_prediction) for raw_prediction in raw_predictions]

    def __call__(self, batch: Union[np.ndarray, Tuple[Image, ...]]) -> List[Union[str, Dict[str, float]]]:
        """Shorthand for :meth:`predict`."""
        return self.predict(batch)