Merge in RR/image-prediction from refactoring to master
Squashed commit of the following:
commit fc4e2efac113f2e307fdbc091e0a4f4e3e5729d3
Author: Matthias Bisping <matthias.bisping@iqser.com>
Date: Wed Mar 16 14:21:05 2022 +0100
applied black
commit 3baabf5bc0b04347af85dafbb056f134258d9715
Author: Matthias Bisping <matthias.bisping@iqser.com>
Date: Wed Mar 16 14:20:30 2022 +0100
added banner
commit 30e871cfdc79d0ff2e0c26d1b858e55ab1b0453f
Author: Matthias Bisping <matthias.bisping@iqser.com>
Date: Wed Mar 16 14:02:26 2022 +0100
rename logger
commit d76fefd3ff0c4425defca4db218ce4a84c6053f3
Author: Matthias Bisping <matthias.bisping@iqser.com>
Date: Wed Mar 16 14:00:39 2022 +0100
logger refactoring
commit 0e004cbd21ab00b8804901952405fa870bf48e9c
Author: Matthias Bisping <matthias.bisping@iqser.com>
Date: Wed Mar 16 14:00:08 2022 +0100
logger refactoring
commit 49e113f8d85d7973b73f664779906a1347d1522d
Author: Matthias Bisping <matthias.bisping@iqser.com>
Date: Wed Mar 16 13:25:08 2022 +0100
refactoring
commit 7ec3d52e155cb83bed8804d2fee4f5bdf54fb59b
Author: Matthias Bisping <matthias.bisping@iqser.com>
Date: Wed Mar 16 13:21:52 2022 +0100
applied black
commit 06ea0be8aa9344e11b9d92fd526f2b73061bc736
Author: Matthias Bisping <matthias.bisping@iqser.com>
Date: Wed Mar 16 13:21:20 2022 +0100
refactoring
123 lines
5.2 KiB
Python
123 lines
5.2 KiB
Python
from itertools import chain
|
|
from operator import itemgetter
|
|
from typing import List, Dict, Iterable
|
|
|
|
import numpy as np
|
|
|
|
from image_prediction.config import CONFIG
|
|
from image_prediction.locations import MLRUNS_DIR, BASE_WEIGHTS
|
|
from image_prediction.utils import temporary_pdf_file, get_logger
|
|
from incl.redai_image.redai.redai.backend.model.model_handle import ModelHandle
|
|
from incl.redai_image.redai.redai.backend.pdf.image_extraction import extract_and_stitch
|
|
from incl.redai_image.redai.redai.utils.mlflow_reader import MlflowModelReader
|
|
from incl.redai_image.redai.redai.utils.shared import chunk_iterable
|
|
|
|
# Module-level logger, obtained from the project's `get_logger` helper.
logger = get_logger()
|
|
|
|
|
|
class Predictor:
    """`ModelHandle` wrapper.

    Forwards to the wrapped model handle for prediction and produces structured output that is interpretable
    independently of the wrapped model (e.g. with regard to a `.classes_` attribute).
    """

    def __init__(self, model_handle: ModelHandle = None):
        """Initializes a Predictor.

        Initialization failures are logged (with traceback) and NOT re-raised; in that case the instance is left
        without `model_handle` / `classes*` attributes and subsequent prediction calls will fail.

        Args:
            model_handle: ModelHandle object to forward to for prediction. By default, a model handle is loaded from
                the mlflow database via CONFIG.service.run_id.
        """
        try:
            if model_handle is None:
                reader = MlflowModelReader(run_id=CONFIG.service.run_id, mlruns_dir=MLRUNS_DIR)
                model_handle = reader.get_model_handle(BASE_WEIGHTS)
            self.model_handle = model_handle

            # Class identifiers in the column order of the model's probability output. Coerced to an array because
            # they are indexed with index arrays below and in `__make_predictions_human_readable`.
            self.classes = np.asarray(self.model_handle.model.classes_)
            self.classes_readable = np.array(self.model_handle.classes)
            # Readable class names reordered to match the model's output columns.
            self.classes_readable_aligned = self.classes_readable[self.classes]
        except Exception as e:
            # Deliberately best-effort: construction never raises, but the failure must be visible in the logs.
            logger.exception(f"Service estimator initialization failed: {e}")

    def __make_predictions_human_readable(self, probs: np.ndarray) -> List[str]:
        """Translates an n x m matrix of probabilities over classes into an n-element list of readable class names.

        For each row the most probable column is selected, mapped to the model's class identifier and then looked up
        in the model handle's readable class names.

        Args:
            probs: probability matrix (items x classes)

        Returns:
            list with the readable name of the most probable class per item.
        """
        most_probable_columns = np.argmax(probs, axis=1)
        model_classes = self.classes[most_probable_columns]
        return [self.model_handle.classes[c] for c in model_classes]

    def predict(self, images: List, probabilities: bool = False, **kwargs):
        """Gathers predictions for list of images. Assigns each image a class and optionally a probability distribution
        over all classes.

        Args:
            images (List[PIL.Image]) : Images to gather predictions for.
            probabilities: Whether to return dictionaries of the following form instead of strings:
                {
                    "class": predicted class,
                    "probabilities": {
                        "class 1" : class 1 probability,
                        "class 2" : class 2 probability,
                        ...
                    }
                }
            **kwargs: Forwarded to the wrapped model's `predict_proba`.

        Returns:
            By default the return value is a list of classes (meaningful class name strings). Alternatively a list of
            dictionaries with an additional probability field for estimated class probabilities per image can be
            returned. The probabilities of each item are ordered from most to least probable. An empty input yields
            an empty list.
        """
        images = list(images)
        if not images:
            # Nothing to predict; avoid handing an empty batch to the model.
            return []

        X = self.model_handle.prep_images(images)
        probs_per_item = self.model_handle.model.predict_proba(X, **kwargs).astype(float)
        classes = self.__make_predictions_human_readable(probs_per_item)

        if not probabilities:
            # Skip building the per-item probability mappings when they are not requested.
            return classes

        predictions = []
        for predicted_class, probs in zip(classes, probs_per_item):
            class2prob = dict(zip(self.classes_readable_aligned, probs))
            class2prob = dict(sorted(class2prob.items(), key=itemgetter(1), reverse=True))
            predictions.append({"class": predicted_class, "probabilities": class2prob})

        return predictions

    def predict_pdf(self, pdf, verbose=False):
        """Extracts the images of a PDF and predicts a class with probabilities for each of them.

        Args:
            pdf: PDF handed to `temporary_pdf_file` (presumably the raw document content -- confirm against caller).
            verbose: Forwarded to the image extraction.

        Returns:
            Pair `(predictions, metadata)` as produced by `__predict_images`.
        """
        with temporary_pdf_file(pdf) as pdf_path:
            image_metadata_pairs = self.__extract_image_metadata_pairs(pdf_path, verbose=verbose)
            # The lazy extraction is consumed inside the `with` block, while the temporary file is still in scope.
            return self.__predict_images(image_metadata_pairs)

    def __predict_images(self, image_metadata_pairs: Iterable, batch_size: int = CONFIG.service.batch_size):
        """Predicts classes with probabilities for (image, metadata) pairs in batches.

        Args:
            image_metadata_pairs: Iterable of (image, metadata) pairs.
            batch_size: Number of pairs passed to the model per batch.

        Returns:
            Pair `(predictions, metadata)` of iterables that are aligned item by item. If there are no pairs, or a
            ValueError occurs during extraction or prediction, `([], [])` is returned; the latter case is logged.
        """
        predictions_per_chunk = []
        metadata_per_chunk = []

        try:
            for chunk in chunk_iterable(image_metadata_pairs, n=batch_size):
                images, metadata = zip(*chunk)
                predictions_per_chunk.append(self.predict(images, probabilities=True))
                metadata_per_chunk.append(metadata)
        except ValueError as e:
            # Best-effort behaviour is kept (empty result instead of a crash), but the failure is no longer silent.
            logger.warning(f"Image prediction failed, returning empty results: {e}")
            return [], []

        if not predictions_per_chunk:
            # No images were extracted.
            return [], []

        return chain.from_iterable(predictions_per_chunk), chain.from_iterable(metadata_per_chunk)

    @staticmethod
    def __extract_image_metadata_pairs(pdf_path: str, **kwargs):
        """Lazily yields what `extract_and_stitch` produces for the PDF, skipping images that are too small.

        Args:
            pdf_path: Path of the PDF file to extract images from.
            **kwargs: Forwarded to `extract_and_stitch`.

        Yields:
            Items produced by `extract_and_stitch` (consumed as (image, metadata) pairs by `__predict_images`).
        """

        def image_is_large_enough(metadata: dict):
            # Both extents of the bounding box must exceed 2 (in the units of the metadata coordinates).
            x1, x2, y1, y2 = itemgetter("x1", "x2", "y1", "y2")(metadata)
            return abs(x1 - x2) > 2 and abs(y1 - y2) > 2

        yield from extract_and_stitch(pdf_path, convert_to_rgb=True, filter_fn=image_is_large_enough, **kwargs)
|