2022-03-29 23:56:22 +02:00

60 lines
2.3 KiB
Python

from typing import Mapping, List, Union, Tuple
import numpy as np
from PIL.Image import Image
from image_prediction.estimator.adapter.adapter import EstimatorAdapter
from image_prediction.exceptions import UnexpectedPredictionFormat
from image_prediction.utils import get_logger
# Module-level logger obtained from the project's ``get_logger`` helper.
logger = get_logger()
class Classifier:
    """Translate raw estimator predictions into human-readable class labels.

    Abstraction layer over different estimator backends (e.g. keras or scikit-learn). For each backend to be used
    an EstimatorAdapter must be implemented.
    """

    def __init__(self, estimator_adapter: EstimatorAdapter, classes: Mapping[int, str]):
        """Create a classifier on top of an estimator adapter.

        Args:
            estimator_adapter: adapter for a given estimator backend; expected to be a classifier that returns, per
                batch item, either a numeric label, a mapping from numeric labels to probabilities, or an array of
                probabilities whose position ``i`` corresponds to numeric label ``i``
            classes: mapping from a numerical label to a human-readable label for classes
        """
        self.__estimator_adapter = estimator_adapter
        self._classes = classes

    def __validate_dict_prediction_format(self, prediction):
        """Raise UnexpectedPredictionFormat if ``prediction`` has a label missing from ``self._classes``."""
        # A membership test (instead of comparing the largest key against ``len``) is correct for
        # non-contiguous label sets and does not crash on an empty prediction.
        unknown_labels = [label for label in prediction if label not in self._classes]
        if unknown_labels:
            raise UnexpectedPredictionFormat(f"Received prediction in an unexpected format: {prediction}")

    def __validate_array_prediction_format(self, prediction):
        """Raise UnexpectedPredictionFormat if the number of probabilities differs from the number of classes."""
        if len(prediction) != len(self._classes):
            raise UnexpectedPredictionFormat(
                f"Received a number of probabilities ({len(prediction)}) different from the number of classes "
                f"specified ({len(self._classes)})."
            )

    def __format_prediction(self, prediction):
        """Translate one raw prediction into human-readable form.

        - an integer label (Python ``int`` or NumPy integer) becomes its class name;
        - a dict ``{label: probability}`` becomes ``{class name: probability}``;
        - a NumPy array of probabilities becomes ``{class name: probability}``, position ``i`` being label ``i``;
        - anything else is returned unchanged.

        Raises:
            UnexpectedPredictionFormat: if a dict or array prediction does not fit ``self._classes``.
        """
        # NumPy integer scalars (what iterating an array of labels yields) are not instances of ``int``.
        if isinstance(prediction, (int, np.integer)):
            return self._classes[prediction]
        if isinstance(prediction, dict):
            self.__validate_dict_prediction_format(prediction)
            return {self._classes[label]: probability for label, probability in prediction.items()}
        if isinstance(prediction, np.ndarray):
            self.__validate_array_prediction_format(prediction)
            # Look labels up through the mapping so the keys are class names, consistent with the branches above.
            return {self._classes[label]: probability for label, probability in enumerate(prediction)}
        return prediction

    def predict(self, batch: Union[np.ndarray, Tuple[Image, ...]]) -> List[Union[str, Mapping[str, float]]]:
        """Predict classes for every item of ``batch``.

        Args:
            batch: array whose first axis indexes the items, or a tuple of images
        Returns:
            one formatted prediction per item returned by the estimator adapter; an empty list for an empty batch
        """
        # ``len`` works for arrays (size of the first axis) and tuples alike, so an empty batch of either kind
        # never reaches the estimator.
        if len(batch) == 0:
            return []
        return [self.__format_prediction(prediction) for prediction in self.__estimator_adapter.predict(batch)]

    def __call__(self, batch: Union[np.ndarray, Tuple[Image, ...]]) -> List[Union[str, Mapping[str, float]]]:
        """Alias for :meth:`predict`."""
        return self.predict(batch)