support for array prediction format
This commit is contained in:
parent
15c0b73034
commit
e9489287bd
@ -23,18 +23,28 @@ class Classifier:
|
|||||||
self.__estimator_adapter = estimator_adapter
|
self.__estimator_adapter = estimator_adapter
|
||||||
self._classes = classes
|
self._classes = classes
|
||||||
|
|
||||||
def __validate_prediction_format(self, prediction):
|
def __validate_dict_prediction_format(self, prediction):
|
||||||
if not max(prediction.keys) <= len(self._classes):
|
if not max(prediction.keys) <= len(self._classes):
|
||||||
raise UnexpectedPredictionFormat(f"Received prediction in an unexpected format: {prediction}")
|
raise UnexpectedPredictionFormat(f"Received prediction in an unexpected format: {prediction}")
|
||||||
|
|
||||||
|
def __validate_array_prediction_format(self, prediction):
|
||||||
|
if not len(prediction) == len(self._classes):
|
||||||
|
raise UnexpectedPredictionFormat(
|
||||||
|
f"Received fewer probabilities ({len(prediction)}) than classes were specified ({len(self._classes)}."
|
||||||
|
)
|
||||||
|
|
||||||
def __format_prediction(self, prediction):
|
def __format_prediction(self, prediction):
|
||||||
if isinstance(prediction, int):
|
if isinstance(prediction, int):
|
||||||
return self._classes[prediction]
|
return self._classes[prediction]
|
||||||
|
|
||||||
elif isinstance(prediction, dict):
|
elif isinstance(prediction, dict):
|
||||||
self.__validate_prediction_format(prediction)
|
self.__validate_dict_prediction_format(prediction)
|
||||||
return {self._classes[cls_idx] for cls_idx, prob in prediction.items()}
|
return {self._classes[cls_idx] for cls_idx, prob in prediction.items()}
|
||||||
|
|
||||||
|
elif isinstance(prediction, np.ndarray):
|
||||||
|
self.__validate_array_prediction_format(prediction)
|
||||||
|
return dict(zip(self._classes, prediction))
|
||||||
|
|
||||||
else:
|
else:
|
||||||
return prediction
|
return prediction
|
||||||
|
|
||||||
|
|||||||
@ -3,7 +3,4 @@ class EstimatorAdapter:
|
|||||||
self.estimator = estimator
|
self.estimator = estimator
|
||||||
|
|
||||||
def predict(self, batch):
|
def predict(self, batch):
|
||||||
return self.estimator.predict(batch)
|
return self.estimator(batch)
|
||||||
|
|
||||||
def predict_proba(self, batch):
|
|
||||||
return self.estimator.predict_proba(batch)
|
|
||||||
|
|||||||
@ -19,4 +19,4 @@ class UnexpectedPredictionFormat(ValueError):
|
|||||||
|
|
||||||
|
|
||||||
class IncorrectInstantiation(RuntimeError):
|
class IncorrectInstantiation(RuntimeError):
|
||||||
pass
|
pass
|
||||||
|
|||||||
@ -21,7 +21,7 @@ class ExtractorClassifier:
|
|||||||
return []
|
return []
|
||||||
|
|
||||||
predictions = self.classifier(images)
|
predictions = self.classifier(images)
|
||||||
responses = ({"label": lbl, **mdt} for lbl, mdt in zip(predictions, metadata))
|
responses = ({"prediction": prd, **mdt} for prd, mdt in zip(predictions, metadata))
|
||||||
return responses
|
return responses
|
||||||
|
|
||||||
def __call__(self, obj) -> Iterable[ImageMetadataPair]:
|
def __call__(self, obj) -> Iterable[ImageMetadataPair]:
|
||||||
|
|||||||
@ -6,11 +6,14 @@ class PredictionModelHandle:
|
|||||||
|
|
||||||
def __init__(self, model_handle):
|
def __init__(self, model_handle):
|
||||||
self.__model_handle = model_handle
|
self.__model_handle = model_handle
|
||||||
|
self.__predict = rcompose(self.__model_handle.prep_images, self.__model_handle.model.predict)
|
||||||
|
self.__predict_proba = rcompose(self.__model_handle.prep_images, self.__model_handle.model.predict_proba)
|
||||||
|
|
||||||
def predict(self, *args, **kwargs):
|
def predict(self, *args, **kwargs):
|
||||||
predict = rcompose(self.__model_handle.prep_images, self.__model_handle.model.predict)
|
return self.__predict(*args, **kwargs)
|
||||||
return predict(*args, **kwargs)
|
|
||||||
|
|
||||||
def predict_proba(self, *args, **kwargs):
|
def predict_proba(self, *args, **kwargs):
|
||||||
predict = rcompose(self.__model_handle.prep_images, self.__model_handle.model.predict_proba)
|
return self.__predict_proba(*args, **kwargs)
|
||||||
return predict(*args, **kwargs)
|
|
||||||
|
def __call__(self, *args, **kwargs):
|
||||||
|
return self.predict_proba(*args, **kwargs)
|
||||||
|
|||||||
@ -52,6 +52,9 @@ class EstimatorMock:
|
|||||||
def predict(batch):
|
def predict(batch):
|
||||||
return [None for _ in batch]
|
return [None for _ in batch]
|
||||||
|
|
||||||
|
def __call__(self, batch):
|
||||||
|
return self.predict(batch)
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def estimator_adapter(estimator_type, keras_model, output_batch_generator, monkeypatch):
|
def estimator_adapter(estimator_type, keras_model, output_batch_generator, monkeypatch):
|
||||||
|
|||||||
@ -11,5 +11,5 @@ from image_prediction.extractor_classifier.extractor_classifier import Extractor
|
|||||||
def test_extractor_classifier(image_extractor, image_classifier, images, expected_predictions):
|
def test_extractor_classifier(image_extractor, image_classifier, images, expected_predictions):
|
||||||
extractor_classifier = ExtractorClassifier(image_extractor, image_classifier)
|
extractor_classifier = ExtractorClassifier(image_extractor, image_classifier)
|
||||||
results = extractor_classifier(images)
|
results = extractor_classifier(images)
|
||||||
labels = list(map(itemgetter("label"), results))
|
labels = list(map(itemgetter("prediction"), results))
|
||||||
assert labels == expected_predictions
|
assert labels == expected_predictions
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user