support for array prediction format

This commit is contained in:
Matthias Bisping 2022-03-29 23:56:22 +02:00
parent 15c0b73034
commit e9489287bd
7 changed files with 26 additions and 13 deletions

View File

@ -23,18 +23,28 @@ class Classifier:
self.__estimator_adapter = estimator_adapter
self._classes = classes
def __validate_prediction_format(self, prediction):
def __validate_dict_prediction_format(self, prediction):
if not max(prediction.keys) <= len(self._classes):
raise UnexpectedPredictionFormat(f"Received prediction in an unexpected format: {prediction}")
def __validate_array_prediction_format(self, prediction):
if not len(prediction) == len(self._classes):
raise UnexpectedPredictionFormat(
f"Received fewer probabilities ({len(prediction)}) than classes were specified ({len(self._classes)}."
)
def __format_prediction(self, prediction):
if isinstance(prediction, int):
return self._classes[prediction]
elif isinstance(prediction, dict):
self.__validate_prediction_format(prediction)
self.__validate_dict_prediction_format(prediction)
return {self._classes[cls_idx] for cls_idx, prob in prediction.items()}
elif isinstance(prediction, np.ndarray):
self.__validate_array_prediction_format(prediction)
return dict(zip(self._classes, prediction))
else:
return prediction

View File

@ -3,7 +3,4 @@ class EstimatorAdapter:
self.estimator = estimator
def predict(self, batch):
return self.estimator.predict(batch)
def predict_proba(self, batch):
return self.estimator.predict_proba(batch)
return self.estimator(batch)

View File

@ -19,4 +19,4 @@ class UnexpectedPredictionFormat(ValueError):
class IncorrectInstantiation(RuntimeError):
pass
pass

View File

@ -21,7 +21,7 @@ class ExtractorClassifier:
return []
predictions = self.classifier(images)
responses = ({"label": lbl, **mdt} for lbl, mdt in zip(predictions, metadata))
responses = ({"prediction": prd, **mdt} for prd, mdt in zip(predictions, metadata))
return responses
def __call__(self, obj) -> Iterable[ImageMetadataPair]:

View File

@ -6,11 +6,14 @@ class PredictionModelHandle:
def __init__(self, model_handle):
self.__model_handle = model_handle
self.__predict = rcompose(self.__model_handle.prep_images, self.__model_handle.model.predict)
self.__predict_proba = rcompose(self.__model_handle.prep_images, self.__model_handle.model.predict_proba)
def predict(self, *args, **kwargs):
predict = rcompose(self.__model_handle.prep_images, self.__model_handle.model.predict)
return predict(*args, **kwargs)
return self.__predict(*args, **kwargs)
def predict_proba(self, *args, **kwargs):
predict = rcompose(self.__model_handle.prep_images, self.__model_handle.model.predict_proba)
return predict(*args, **kwargs)
return self.__predict_proba(*args, **kwargs)
def __call__(self, *args, **kwargs):
return self.predict_proba(*args, **kwargs)

View File

@ -52,6 +52,9 @@ class EstimatorMock:
def predict(batch):
return [None for _ in batch]
def __call__(self, batch):
return self.predict(batch)
@pytest.fixture
def estimator_adapter(estimator_type, keras_model, output_batch_generator, monkeypatch):

View File

@ -11,5 +11,5 @@ from image_prediction.extractor_classifier.extractor_classifier import Extractor
def test_extractor_classifier(image_extractor, image_classifier, images, expected_predictions):
extractor_classifier = ExtractorClassifier(image_extractor, image_classifier)
results = extractor_classifier(images)
labels = list(map(itemgetter("label"), results))
labels = list(map(itemgetter("prediction"), results))
assert labels == expected_predictions