removed obsolete code
This commit is contained in:
parent
692e72b3b2
commit
9cda65ad41
@ -1,122 +0,0 @@
|
|||||||
from itertools import chain
|
|
||||||
from operator import itemgetter
|
|
||||||
from typing import List, Dict, Iterable
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
|
|
||||||
from image_prediction.config import CONFIG
|
|
||||||
from image_prediction.locations import MLRUNS_DIR, BASE_WEIGHTS
|
|
||||||
from image_prediction.utils import temporary_pdf_file, get_logger
|
|
||||||
from incl.redai_image.redai.redai.backend.model.model_handle import ModelHandle
|
|
||||||
from incl.redai_image.redai.redai.backend.pdf.image_extraction import extract_and_stitch
|
|
||||||
from incl.redai_image.redai.redai.utils.mlflow_reader import MlflowModelReader
|
|
||||||
from incl.redai_image.redai.redai.utils.shared import chunk_iterable
|
|
||||||
|
|
||||||
logger = get_logger()
|
|
||||||
|
|
||||||
|
|
||||||
class Predictor:
|
|
||||||
"""`ModelHandle` wrapper. Forwards to wrapped service_estimator handle for prediction and produces structured output that is
|
|
||||||
interpretable independently of the wrapped service_estimator (e.g. with regard to a .classes_ attribute).
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, model_handle: ModelHandle = None):
|
|
||||||
"""Initializes a ServiceEstimator.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
model_handle: ModelHandle object to forward to for prediction. By default, a service_estimator handle is loaded from the
|
|
||||||
mlflow database via CONFIG.service.run_id.
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
if model_handle is None:
|
|
||||||
reader = MlflowModelReader(run_id=CONFIG.service.run_id, mlruns_dir=MLRUNS_DIR)
|
|
||||||
self.model_handle = reader.get_model_handle(BASE_WEIGHTS)
|
|
||||||
else:
|
|
||||||
self.model_handle = model_handle
|
|
||||||
|
|
||||||
self.classes = self.model_handle.model.classes_
|
|
||||||
self.classes_readable = np.array(self.model_handle.classes)
|
|
||||||
self.classes_readable_aligned = self.classes_readable[self.classes[list(range(len(self.classes)))]]
|
|
||||||
except Exception as e:
|
|
||||||
logger.info(f"Service estimator initialization failed: {e}")
|
|
||||||
|
|
||||||
def __make_predictions_human_readable(self, probs: np.ndarray) -> List[Dict[str, float]]:
|
|
||||||
"""Translates an n x m matrix of probabilities over classes into an n-element list of mappings from classes to
|
|
||||||
probabilities.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
probs: probability matrix (items x classes)
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
list of mappings from classes to probabilities.
|
|
||||||
"""
|
|
||||||
classes = np.argmax(probs, axis=1)
|
|
||||||
classes = self.classes[classes]
|
|
||||||
classes_readable = [self.model_handle.classes[c] for c in classes]
|
|
||||||
return classes_readable
|
|
||||||
|
|
||||||
def predict(self, images: List, probabilities: bool = False, **kwargs):
|
|
||||||
"""Gathers predictions for list of images. Assigns each image a class and optionally a probability distribution
|
|
||||||
over all classes.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
images (List[PIL.Image]) : Images to gather predictions for.
|
|
||||||
probabilities: Whether to return dictionaries of the following form instead of strings:
|
|
||||||
{
|
|
||||||
"class": predicted class,
|
|
||||||
"probabilities": {
|
|
||||||
"class 1" : class 1 probability,
|
|
||||||
"class 2" : class 2 probability,
|
|
||||||
...
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
By default the return value is a list of classes (meaningful class name strings). Alternatively a list of
|
|
||||||
dictionaries with an additional probability field for estimated class probabilities per image can be
|
|
||||||
returned.
|
|
||||||
"""
|
|
||||||
X = self.model_handle.prep_images(list(images))
|
|
||||||
|
|
||||||
probs_per_item = self.model_handle.model.predict_proba(X, **kwargs).astype(float)
|
|
||||||
classes = self.__make_predictions_human_readable(probs_per_item)
|
|
||||||
|
|
||||||
class2prob_per_item = [dict(zip(self.classes_readable_aligned, probs)) for probs in probs_per_item]
|
|
||||||
class2prob_per_item = [
|
|
||||||
dict(sorted(c2p.items(), key=itemgetter(1), reverse=True)) for c2p in class2prob_per_item
|
|
||||||
]
|
|
||||||
|
|
||||||
predictions = [{"class": c, "probabilities": c2p} for c, c2p in zip(classes, class2prob_per_item)]
|
|
||||||
|
|
||||||
return predictions if probabilities else classes
|
|
||||||
|
|
||||||
def predict_pdf(self, pdf, verbose=False):
|
|
||||||
with temporary_pdf_file(pdf) as pdf_path:
|
|
||||||
image_metadata_pairs = self.__extract_image_metadata_pairs(pdf_path, verbose=verbose)
|
|
||||||
return self.__predict_images(image_metadata_pairs)
|
|
||||||
|
|
||||||
def __predict_images(self, image_metadata_pairs: Iterable, batch_size: int = CONFIG.service.batch_size):
|
|
||||||
def process_chunk(chunk):
|
|
||||||
images, metadata = zip(*chunk)
|
|
||||||
predictions = self.predict(images, probabilities=True)
|
|
||||||
return predictions, metadata
|
|
||||||
|
|
||||||
def predict(image_metadata_pair_generator):
|
|
||||||
chunks = chunk_iterable(image_metadata_pair_generator, n=batch_size)
|
|
||||||
return map(chain.from_iterable, zip(*map(process_chunk, chunks)))
|
|
||||||
|
|
||||||
try:
|
|
||||||
predictions, metadata = predict(image_metadata_pairs)
|
|
||||||
return predictions, metadata
|
|
||||||
|
|
||||||
except ValueError:
|
|
||||||
return [], []
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def __extract_image_metadata_pairs(pdf_path: str, **kwargs):
|
|
||||||
def image_is_large_enough(metadata: dict):
|
|
||||||
x1, x2, y1, y2 = itemgetter("x1", "x2", "y1", "y2")(metadata)
|
|
||||||
|
|
||||||
return abs(x1 - x2) > 2 and abs(y1 - y2) > 2
|
|
||||||
|
|
||||||
yield from extract_and_stitch(pdf_path, convert_to_rgb=True, filter_fn=image_is_large_enough, **kwargs)
|
|
||||||
@ -1,70 +0,0 @@
|
|||||||
"""Defines functions for constructing service responses."""
|
|
||||||
|
|
||||||
|
|
||||||
import math
|
|
||||||
from itertools import starmap
|
|
||||||
from operator import itemgetter
|
|
||||||
|
|
||||||
from image_prediction.config import CONFIG
|
|
||||||
|
|
||||||
|
|
||||||
def build_response(predictions: list, metadata: list) -> list:
|
|
||||||
return list(starmap(build_image_info, zip(predictions, metadata)))
|
|
||||||
|
|
||||||
|
|
||||||
def build_image_info(prediction: dict, metadata: dict) -> dict:
|
|
||||||
def compute_geometric_quotient():
|
|
||||||
page_area_sqrt = math.sqrt(abs(page_width * page_height))
|
|
||||||
image_area_sqrt = math.sqrt(abs(x2 - x1) * abs(y2 - y1))
|
|
||||||
return image_area_sqrt / page_area_sqrt
|
|
||||||
|
|
||||||
page_width, page_height, x1, x2, y1, y2, width, height = itemgetter(
|
|
||||||
"page_width", "page_height", "x1", "x2", "y1", "y2", "width", "height"
|
|
||||||
)(metadata)
|
|
||||||
|
|
||||||
quotient = compute_geometric_quotient()
|
|
||||||
|
|
||||||
min_image_to_page_quotient_breached = bool(quotient < CONFIG.filters.image_to_page_quotient.min)
|
|
||||||
max_image_to_page_quotient_breached = bool(quotient > CONFIG.filters.image_to_page_quotient.max)
|
|
||||||
min_image_width_to_height_quotient_breached = bool(
|
|
||||||
width / height < CONFIG.filters.image_width_to_height_quotient.min
|
|
||||||
)
|
|
||||||
max_image_width_to_height_quotient_breached = bool(
|
|
||||||
width / height > CONFIG.filters.image_width_to_height_quotient.max
|
|
||||||
)
|
|
||||||
|
|
||||||
min_confidence_breached = bool(max(prediction["probabilities"].values()) < CONFIG.filters.min_confidence)
|
|
||||||
prediction["label"] = prediction.pop("class") # "class" as field name causes problem for Java objectmapper
|
|
||||||
prediction["probabilities"] = {klass: round(prob, 6) for klass, prob in prediction["probabilities"].items()}
|
|
||||||
|
|
||||||
image_info = {
|
|
||||||
"classification": prediction,
|
|
||||||
"position": {"x1": x1, "x2": x2, "y1": y1, "y2": y2, "pageNumber": metadata["page_idx"] + 1},
|
|
||||||
"geometry": {"width": width, "height": height},
|
|
||||||
"filters": {
|
|
||||||
"geometry": {
|
|
||||||
"imageSize": {
|
|
||||||
"quotient": quotient,
|
|
||||||
"tooLarge": max_image_to_page_quotient_breached,
|
|
||||||
"tooSmall": min_image_to_page_quotient_breached,
|
|
||||||
},
|
|
||||||
"imageFormat": {
|
|
||||||
"quotient": width / height,
|
|
||||||
"tooTall": min_image_width_to_height_quotient_breached,
|
|
||||||
"tooWide": max_image_width_to_height_quotient_breached,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
"probability": {"unconfident": min_confidence_breached},
|
|
||||||
"allPassed": not any(
|
|
||||||
[
|
|
||||||
max_image_to_page_quotient_breached,
|
|
||||||
min_image_to_page_quotient_breached,
|
|
||||||
min_image_width_to_height_quotient_breached,
|
|
||||||
max_image_width_to_height_quotient_breached,
|
|
||||||
min_confidence_breached,
|
|
||||||
]
|
|
||||||
),
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
return image_info
|
|
||||||
@ -1,49 +0,0 @@
|
|||||||
import logging
|
|
||||||
|
|
||||||
from waitress import serve
|
|
||||||
|
|
||||||
from image_prediction.config import CONFIG
|
|
||||||
from image_prediction.flask import make_prediction_server
|
|
||||||
from image_prediction.predictor import Predictor
|
|
||||||
from image_prediction.response import build_response
|
|
||||||
from image_prediction.utils import get_logger, show_banner
|
|
||||||
|
|
||||||
logger = get_logger()
|
|
||||||
|
|
||||||
|
|
||||||
def main():
|
|
||||||
def predict(pdf):
|
|
||||||
# Keras service_estimator.predict stalls when service_estimator was loaded in different process
|
|
||||||
# https://stackoverflow.com/questions/42504669/keras-tensorflow-and-multiprocessing-in-python
|
|
||||||
predictor = Predictor()
|
|
||||||
predictions, metadata = predictor.predict_pdf(pdf, verbose=CONFIG.service.progressbar)
|
|
||||||
response = build_response(predictions, metadata)
|
|
||||||
return response
|
|
||||||
|
|
||||||
logger.info("Predictor ready.")
|
|
||||||
|
|
||||||
prediction_server = make_prediction_server(predict)
|
|
||||||
|
|
||||||
run_prediction_server(prediction_server, mode=CONFIG.webserver.mode)
|
|
||||||
|
|
||||||
|
|
||||||
def run_prediction_server(app, mode="development"):
|
|
||||||
if mode == "development":
|
|
||||||
app.run(host=CONFIG.webserver.host, port=CONFIG.webserver.port, debug=True)
|
|
||||||
elif mode == "production":
|
|
||||||
serve(app, host=CONFIG.webserver.host, port=CONFIG.webserver.port)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
logging_level = CONFIG.service.logging_level
|
|
||||||
logging.basicConfig(level=logging_level)
|
|
||||||
logging.getLogger("flask").setLevel(logging.ERROR)
|
|
||||||
logging.getLogger("urllib3").setLevel(logging.ERROR)
|
|
||||||
logging.getLogger("werkzeug").setLevel(logging.ERROR)
|
|
||||||
logging.getLogger("waitress").setLevel(logging.ERROR)
|
|
||||||
logging.getLogger("PIL").setLevel(logging.ERROR)
|
|
||||||
logging.getLogger("h5py").setLevel(logging.ERROR)
|
|
||||||
|
|
||||||
show_banner()
|
|
||||||
|
|
||||||
main()
|
|
||||||
Loading…
x
Reference in New Issue
Block a user