64 lines
2.4 KiB
Python
64 lines
2.4 KiB
Python
from operator import itemgetter
|
|
from typing import Mapping, List, Union, Tuple
|
|
|
|
import numpy as np
|
|
from PIL.Image import Image
|
|
|
|
from image_prediction.estimator.adapter.adapter import EstimatorAdapter
|
|
from image_prediction.exceptions import UnexpectedLabelFormat
|
|
from image_prediction.utils import get_logger
|
|
|
|
# Module-level logger from the project's get_logger helper; not referenced in the visible part of this file.
logger = get_logger()
|
|
|
|
|
|
class Classifier:
    """Translate raw estimator predictions into class labels.

    The batch is handed to an ``EstimatorAdapter``; every element of the adapter's result is then
    formatted individually (see ``predict`` for the per-element rules).
    """

    def __init__(self, estimator_adapter: "EstimatorAdapter", classes: Mapping[int, str]):
        """Abstraction layer over different estimator backends (e.g. keras or scikit-learn). For each backend to be used
        an EstimatorAdapter must be implemented.

        Args:
            estimator_adapter: adapter for a given estimator backend
            classes: mapping from a numerical label to a human-readable label for classes
        """
        self.__estimator_adapter = estimator_adapter
        self._classes = classes

    def __validate_array_prediction_format(self, prediction):
        """Check that an array prediction holds exactly one probability per class.

        Args:
            prediction: array of per-class probabilities for a single sample

        Raises:
            UnexpectedLabelFormat: if the number of probabilities differs from the number of classes
        """
        if len(prediction) != len(self._classes):
            # The check rejects both too few and too many values, so the message must not claim "fewer".
            raise UnexpectedLabelFormat(
                f"Received a different number of probabilities ({len(prediction)}) than classes were specified "
                f"({len(self._classes)})."
            )

    def __validate_int_prediction_format(self, prediction):
        """Check that an integer prediction is a valid class index.

        Args:
            prediction: class index predicted for a single sample

        Raises:
            UnexpectedLabelFormat: if the index lies outside ``[0, len(classes))``
        """
        # Upper bound is exclusive: an index equal to the number of classes has no label.
        if not 0 <= prediction < len(self._classes):
            raise UnexpectedLabelFormat(
                f"Received class index '{prediction}' as prediction that has no associated class label."
            )

    def __format_array_prediction_format(self, prediction):
        """Pair classes with their probabilities and pick the most likely one.

        Args:
            prediction: array of per-class probabilities, positionally aligned with the iteration order of the classes

        Returns:
            Dict with the key ``"label"`` (entry with the highest probability) and the key ``"probabilities"``
            (dict of all entries, ordered by descending probability).
        """
        # NOTE(review): iterating a Mapping yields its keys, so with a dict for `classes` the entries here are the
        # numerical keys, not the human-readable names - confirm this is intended.
        cls2prob = dict(sorted(zip(self._classes, prediction), key=itemgetter(1), reverse=True))
        most_likely = [*cls2prob][0]
        return {"label": most_likely, "probabilities": cls2prob}

    def __format_prediction(self, prediction):
        """Format a single prediction depending on its type.

        Args:
            prediction: one element of the estimator adapter's output

        Returns:
            The class label for an ``int``, a label/probabilities dict for a ``np.ndarray``, otherwise the
            prediction unchanged.

        Raises:
            UnexpectedLabelFormat: if an ``int`` or ``np.ndarray`` prediction fails validation
        """
        # NOTE(review): numpy integer scalars (e.g. np.int64) are not `int` instances and therefore fall through to
        # the pass-through branch - confirm this is intended.
        if isinstance(prediction, int):
            self.__validate_int_prediction_format(prediction)
            return self._classes[prediction]

        if isinstance(prediction, np.ndarray):
            self.__validate_array_prediction_format(prediction)
            return self.__format_array_prediction_format(prediction)

        return prediction

    def predict(self, batch: Union[np.ndarray, Tuple["Image", ...]]) -> List[str]:
        """Predict on a batch and format every resulting prediction.

        Args:
            batch: array or tuple of images that is passed on to the estimator adapter

        Returns:
            One formatted prediction per element returned by the adapter. Despite the annotation, elements are
            dicts for array predictions and are passed through unchanged for types other than ``int`` and
            ``np.ndarray``. An empty list is returned for a non-tuple batch whose first dimension is 0.

        Raises:
            UnexpectedLabelFormat: if a prediction does not match the configured classes
        """
        # Short-circuit empty arrays so the estimator is never called without samples.
        if not isinstance(batch, tuple) and batch.shape[0] == 0:
            return []

        return list(map(self.__format_prediction, self.__estimator_adapter.predict(batch)))

    def __call__(self, batch: Union[np.ndarray, Tuple["Image", ...]]) -> List[str]:
        """Make the classifier callable; identical to ``predict``."""
        return self.predict(batch)
|