import logging
from itertools import chain
from operator import itemgetter
from typing import List, Dict, Iterable

import numpy as np

from image_prediction.config import CONFIG
from image_prediction.locations import MLRUNS_DIR, BASE_WEIGHTS
from incl.redai_image.redai.redai.backend.model.model_handle import ModelHandle
from incl.redai_image.redai.redai.backend.pdf.image_extraction import extract_and_stitch
from incl.redai_image.redai.redai.utils.mlflow_reader import MlflowModelReader
from incl.redai_image.redai.redai.utils.shared import chunk_iterable


class Predictor:
    """`ModelHandle` wrapper.

    Forwards to the wrapped model handle for prediction and produces structured output that is
    interpretable independently of the wrapped model (e.g. with regard to a .classes_ attribute).
    """

    def __init__(self, model_handle: ModelHandle = None):
        """Initializes a Predictor.

        Initialization is best-effort: any exception raised while loading the model handle or
        deriving the class lookups is logged and swallowed, so a failed instance may lack some or
        all of its attributes.

        Args:
            model_handle: ModelHandle object to forward to for prediction. By default, a model
                handle is loaded from the mlflow database via CONFIG.service.run_id.
        """
        try:
            if model_handle is None:
                reader = MlflowModelReader(run_id=CONFIG.service.run_id, mlruns_dir=MLRUNS_DIR)
                self.model_handle = reader.get_model_handle(BASE_WEIGHTS)
            else:
                self.model_handle = model_handle

            # Class identifiers in the column order of the model's probability output.
            self.classes = self.model_handle.model.classes_
            # Readable class names, looked up by class identifier.
            self.classes_readable = np.array(self.model_handle.classes)
            # Readable class names re-ordered to match the model's probability columns.
            self.classes_readable_aligned = self.classes_readable[self.classes]
        except Exception as e:
            # Logged at ERROR level with traceback: INFO records are dropped by the root logger's
            # default configuration, which would hide a failed model load entirely.
            logging.exception(f"Service estimator initialization failed: {e}")

    def __make_predictions_human_readable(self, probs: np.ndarray) -> List[str]:
        """Translates an n x m matrix of probabilities over classes into an n-element list of
        readable class names, one (the most probable class) per item.

        Args:
            probs: probability matrix (items x classes)

        Returns:
            list with the readable name of the highest-probability class of each item.
        """
        # Column index of the most probable class per row ...
        best_columns = np.argmax(probs, axis=1)
        # ... mapped to the model's class identifiers ...
        class_ids = self.classes[best_columns]
        # ... and from there to the readable names held by the model handle.
        return [self.model_handle.classes[class_id] for class_id in class_ids]

    def predict(self, images: List, probabilities: bool = False, **kwargs):
        """Gathers predictions for list of images.

        Assigns each image a class and optionally a probability distribution over all classes.

        Args:
            images (List[PIL.Image]) : Images to gather predictions for.
            probabilities: Whether to return dictionaries of the following form instead of strings:
                {
                    "class": predicted class,
                    "probabilities": {
                        "class 1" : class 1 probability,
                        "class 2" : class 2 probability,
                        ...
                    }
                }
            **kwargs: Forwarded to the wrapped model's `predict_proba`.

        Returns:
            By default the return value is a list of classes (meaningful class name strings).
            Alternatively a list of dictionaries with an additional probability field for
            estimated class probabilities per image can be returned; the probability mapping of
            each item is ordered from most to least probable class.
        """
        X = self.model_handle.prep_images(list(images))
        probs_per_item = self.model_handle.model.predict_proba(X, **kwargs).astype(float)
        classes = self.__make_predictions_human_readable(probs_per_item)

        if not probabilities:
            # Skip building the per-class mappings when only the class names are requested.
            return classes

        predictions = []
        for predicted_class, probs in zip(classes, probs_per_item):
            class2prob = zip(self.classes_readable_aligned, probs)
            # Order each mapping by descending probability.
            class2prob_sorted = dict(sorted(class2prob, key=itemgetter(1), reverse=True))
            predictions.append({"class": predicted_class, "probabilities": class2prob_sorted})
        return predictions


def extract_image_metadata_pairs(pdf_path: str, **kwargs):
    """Yields whatever `extract_and_stitch` produces for the PDF at `pdf_path`.

    Extraction is requested with RGB conversion and with a filter that only accepts entries whose
    metadata describes a box wider and taller than 2 units (units as used by the metadata's
    x1/x2/y1/y2 fields). The items are presumably (image, metadata) pairs, as the function name and
    `classify_images` suggest; the exact item type is defined by `extract_and_stitch`.

    Args:
        pdf_path: Path of the PDF to extract from.
        **kwargs: Forwarded to `extract_and_stitch`.
    """

    def image_is_large_enough(metadata: dict):
        x1, x2, y1, y2 = itemgetter("x1", "x2", "y1", "y2")(metadata)
        return abs(x1 - x2) > 2 and abs(y1 - y2) > 2

    yield from extract_and_stitch(pdf_path, convert_to_rgb=True, filter_fn=image_is_large_enough, **kwargs)


def classify_images(predictor, image_metadata_pairs: Iterable, batch_size: int = CONFIG.service.batch_size):
    """Classifies images in batches and returns predictions alongside the images' metadata.

    Args:
        predictor: Object with a `predict(images, probabilities=True)` method, e.g. a `Predictor`.
        image_metadata_pairs: Iterable of (image, metadata) pairs.
        batch_size: Number of pairs handed to the predictor per call. The default is read from
            CONFIG.service.batch_size once, when this function is defined.

    Returns:
        A pair (predictions, metadata). On success both are iterators over the flattened per-batch
        results, in input order. If there is nothing to classify, or a ValueError is raised while
        chunking or predicting, both are empty lists.
    """

    def process_chunk(chunk):
        images, metadata = zip(*chunk)
        predictions = predictor.predict(images, probabilities=True)
        return predictions, metadata

    try:
        chunks = chunk_iterable(image_metadata_pairs, n=batch_size)
        results = [process_chunk(chunk) for chunk in chunks]
    except ValueError:
        # Kept as a best-effort fallback, but no longer silent: a ValueError raised during
        # chunking or prediction is now distinguishable from an input that was simply empty.
        logging.warning("Image classification failed with a ValueError; returning no predictions.", exc_info=True)
        return [], []

    if not results:
        # Nothing to classify.
        return [], []

    predictions_per_chunk, metadata_per_chunk = zip(*results)
    return chain.from_iterable(predictions_per_chunk), chain.from_iterable(metadata_per_chunk)