89 lines
3.7 KiB
Python

import logging
from operator import itemgetter
from typing import List, Dict
import numpy as np
from image_prediction.config import CONFIG
from image_prediction.locations import MLRUNS_DIR, BASE_WEIGHTS
from incl.redai_image.redai.redai.backend.model.model_handle import ModelHandle
from incl.redai_image.redai.redai.utils.mlflow_reader import MlflowModelReader
class Predictor:
    """`ModelHandle` wrapper. Forwards to wrapped model handle for prediction and produces structured output that is
    interpretable independently of the wrapped model (e.g. with regard to a .classes_ attribute).

    Attributes:
        model_handle: The wrapped handle; this class uses its `prep_images`, `classes` and `model` members.
        classes: `model_handle.model.classes_` as an array. Indexed with a column index of the probability matrix
            to obtain an index into `model_handle.classes`.
        classes_readable: `model_handle.classes` as an array.
        classes_readable_aligned: `classes_readable` reordered by `classes`, so that entry j is the readable name
            belonging to column j of the probability matrix.
    """

    def __init__(self, model_handle: "ModelHandle" = None):
        """Initializes a Predictor.

        Initialization is best-effort: any exception raised while loading or inspecting the model handle is logged
        (including its traceback) and swallowed. In that case the instance is left without some or all of its
        attributes and a later `predict` call on non-empty input will fail.

        Args:
            model_handle: ModelHandle object to forward to for prediction. By default, a model handle is loaded from the
                mlflow database via CONFIG.service.run_id.
        """
        # Deliberately broad: construction must not raise, callers rely on getting an instance back.
        try:
            if model_handle is None:
                reader = MlflowModelReader(run_id=CONFIG.service.run_id, mlruns_dir=MLRUNS_DIR)
                model_handle = reader.get_model_handle(BASE_WEIGHTS)
            self.model_handle = model_handle

            # asarray so that the fancy indexing below also works when classes_ is a plain sequence.
            self.classes = np.asarray(self.model_handle.model.classes_)
            self.classes_readable = np.array(self.model_handle.classes)
            # Reorder the readable names into the column order of the model's probability output.
            self.classes_readable_aligned = self.classes_readable[self.classes]
        except Exception as e:
            logging.exception("Service estimator initialization failed: %s", e)

    def __make_predictions_human_readable(self, probs: np.ndarray) -> List[str]:
        """Translates an n x m matrix of probabilities over classes into an n-element list of class names.

        Args:
            probs: probability matrix (items x classes)

        Returns:
            For each item, the entry of `model_handle.classes` that belongs to its most probable column.
        """
        best_columns = np.argmax(probs, axis=1)
        # Column index -> model class value -> readable name.
        labels = self.classes[best_columns]
        return [self.model_handle.classes[label] for label in labels]

    def predict(self, images: List, probabilities: bool = False, **kwargs) -> List:
        """Gathers predictions for list of images. Assigns each image a class and optionally a probability distribution
        over all classes.

        Args:
            images (List[PIL.Image]) : Images to gather predictions for. An empty collection yields an empty list
                without invoking the model.
            probabilities: Whether to return dictionaries of the following form instead of strings:
                {
                    "class": predicted class,
                    "probabilities": {
                        "class 1" : class 1 probability,
                        "class 2" : class 2 probability,
                        ...
                    }
                }
                The "probabilities" mapping is ordered from most to least probable class.
            **kwargs: Forwarded unchanged to the wrapped model's `predict_proba`.

        Returns:
            By default the return value is a list of classes (meaningful class name strings). Alternatively a list of
            dictionaries with an additional probability field for estimated class probabilities per image can be
            returned.
        """
        images = list(images)
        if not images:
            return []

        X = self.model_handle.prep_images(images)
        probs_per_item = self.model_handle.model.predict_proba(X, **kwargs).astype(float)
        classes = self.__make_predictions_human_readable(probs_per_item)

        # Skip building the per-item distributions when the caller only wants class names.
        if not probabilities:
            return classes

        by_probability = itemgetter(1)
        predictions = []
        for predicted_class, probs in zip(classes, probs_per_item):
            class2prob = dict(zip(self.classes_readable_aligned, probs))
            ranked = dict(sorted(class2prob.items(), key=by_probability, reverse=True))
            predictions.append({"class": predicted_class, "probabilities": ranked})
        return predictions