support for array prediction format
This commit is contained in:
parent
15c0b73034
commit
e9489287bd
@ -23,18 +23,28 @@ class Classifier:
|
||||
self.__estimator_adapter = estimator_adapter
|
||||
self._classes = classes
|
||||
|
||||
def __validate_prediction_format(self, prediction):
|
||||
def __validate_dict_prediction_format(self, prediction):
|
||||
if not max(prediction.keys) <= len(self._classes):
|
||||
raise UnexpectedPredictionFormat(f"Received prediction in an unexpected format: {prediction}")
|
||||
|
||||
def __validate_array_prediction_format(self, prediction):
|
||||
if not len(prediction) == len(self._classes):
|
||||
raise UnexpectedPredictionFormat(
|
||||
f"Received fewer probabilities ({len(prediction)}) than classes were specified ({len(self._classes)}."
|
||||
)
|
||||
|
||||
def __format_prediction(self, prediction):
|
||||
if isinstance(prediction, int):
|
||||
return self._classes[prediction]
|
||||
|
||||
elif isinstance(prediction, dict):
|
||||
self.__validate_prediction_format(prediction)
|
||||
self.__validate_dict_prediction_format(prediction)
|
||||
return {self._classes[cls_idx] for cls_idx, prob in prediction.items()}
|
||||
|
||||
elif isinstance(prediction, np.ndarray):
|
||||
self.__validate_array_prediction_format(prediction)
|
||||
return dict(zip(self._classes, prediction))
|
||||
|
||||
else:
|
||||
return prediction
|
||||
|
||||
|
||||
@ -3,7 +3,4 @@ class EstimatorAdapter:
|
||||
self.estimator = estimator
|
||||
|
||||
def predict(self, batch):
|
||||
return self.estimator.predict(batch)
|
||||
|
||||
def predict_proba(self, batch):
|
||||
return self.estimator.predict_proba(batch)
|
||||
return self.estimator(batch)
|
||||
|
||||
@ -19,4 +19,4 @@ class UnexpectedPredictionFormat(ValueError):
|
||||
|
||||
|
||||
class IncorrectInstantiation(RuntimeError):
|
||||
pass
|
||||
pass
|
||||
|
||||
@ -21,7 +21,7 @@ class ExtractorClassifier:
|
||||
return []
|
||||
|
||||
predictions = self.classifier(images)
|
||||
responses = ({"label": lbl, **mdt} for lbl, mdt in zip(predictions, metadata))
|
||||
responses = ({"prediction": prd, **mdt} for prd, mdt in zip(predictions, metadata))
|
||||
return responses
|
||||
|
||||
def __call__(self, obj) -> Iterable[ImageMetadataPair]:
|
||||
|
||||
@ -6,11 +6,14 @@ class PredictionModelHandle:
|
||||
|
||||
def __init__(self, model_handle):
|
||||
self.__model_handle = model_handle
|
||||
self.__predict = rcompose(self.__model_handle.prep_images, self.__model_handle.model.predict)
|
||||
self.__predict_proba = rcompose(self.__model_handle.prep_images, self.__model_handle.model.predict_proba)
|
||||
|
||||
def predict(self, *args, **kwargs):
|
||||
predict = rcompose(self.__model_handle.prep_images, self.__model_handle.model.predict)
|
||||
return predict(*args, **kwargs)
|
||||
return self.__predict(*args, **kwargs)
|
||||
|
||||
def predict_proba(self, *args, **kwargs):
|
||||
predict = rcompose(self.__model_handle.prep_images, self.__model_handle.model.predict_proba)
|
||||
return predict(*args, **kwargs)
|
||||
return self.__predict_proba(*args, **kwargs)
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
return self.predict_proba(*args, **kwargs)
|
||||
|
||||
@ -52,6 +52,9 @@ class EstimatorMock:
|
||||
def predict(batch):
|
||||
return [None for _ in batch]
|
||||
|
||||
def __call__(self, batch):
|
||||
return self.predict(batch)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def estimator_adapter(estimator_type, keras_model, output_batch_generator, monkeypatch):
|
||||
|
||||
@ -11,5 +11,5 @@ from image_prediction.extractor_classifier.extractor_classifier import Extractor
|
||||
def test_extractor_classifier(image_extractor, image_classifier, images, expected_predictions):
|
||||
extractor_classifier = ExtractorClassifier(image_extractor, image_classifier)
|
||||
results = extractor_classifier(images)
|
||||
labels = list(map(itemgetter("label"), results))
|
||||
labels = list(map(itemgetter("prediction"), results))
|
||||
assert labels == expected_predictions
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user