support for different prediction formats
This commit is contained in:
parent
7a64af156b
commit
15c0b73034
@ -4,6 +4,7 @@ import numpy as np
|
||||
from PIL.Image import Image
|
||||
|
||||
from image_prediction.estimator.adapter.adapter import EstimatorAdapter
|
||||
from image_prediction.exceptions import UnexpectedPredictionFormat
|
||||
from image_prediction.utils import get_logger
|
||||
|
||||
logger = get_logger()
|
||||
@ -15,19 +16,34 @@ class Classifier:
|
||||
an EstimatorAdapter must be implemented.
|
||||
|
||||
Args:
|
||||
estimator_adapter: adapter for a given estimator backend; expected to be a classifier that returns numeric
|
||||
labels as predictions
|
||||
estimator_adapter: adapter for a given estimator backend; expected to be a classifier that returns mappings
|
||||
from numeric labels to probabilities as predictions or numeric labels
|
||||
classes: mapping from a numerical label to a human-readable label for classes
|
||||
"""
|
||||
self.__estimator_adapter = estimator_adapter
|
||||
self._classes = classes
|
||||
|
||||
def __validate_prediction_format(self, prediction):
|
||||
if not max(prediction.keys) <= len(self._classes):
|
||||
raise UnexpectedPredictionFormat(f"Received prediction in an unexpected format: {prediction}")
|
||||
|
||||
def __format_prediction(self, prediction):
|
||||
if isinstance(prediction, int):
|
||||
return self._classes[prediction]
|
||||
|
||||
elif isinstance(prediction, dict):
|
||||
self.__validate_prediction_format(prediction)
|
||||
return {self._classes[cls_idx] for cls_idx, prob in prediction.items()}
|
||||
|
||||
else:
|
||||
return prediction
|
||||
|
||||
def predict(self, batch: Union[np.array, Tuple[Image]]) -> List[str]:
|
||||
|
||||
if not isinstance(batch, tuple) and batch.shape[0] == 0:
|
||||
return []
|
||||
|
||||
return [self._classes[numeric_label] for numeric_label in self.__estimator_adapter.predict(batch)]
|
||||
return list(map(self.__format_prediction, self.__estimator_adapter.predict(batch)))
|
||||
|
||||
def __call__(self, batch: np.array) -> List[str]:
|
||||
return self.predict(batch)
|
||||
|
||||
@ -14,5 +14,9 @@ class UnknownDatabaseType(ValueError):
|
||||
pass
|
||||
|
||||
|
||||
class IncorrectInstantiation(RuntimeError):
|
||||
class UnexpectedPredictionFormat(ValueError):
|
||||
pass
|
||||
|
||||
|
||||
class IncorrectInstantiation(RuntimeError):
|
||||
pass
|
||||
Loading…
x
Reference in New Issue
Block a user