Compare commits

...

2 Commits

Author            SHA1        Message        Date
Matthias Bisping  7ec3d52e15  applied black  2022-03-16 13:21:52 +01:00
Matthias Bisping  06ea0be8aa  refactoring    2022-03-16 13:21:20 +01:00
5 changed files with 82 additions and 28 deletions

image_prediction/flask.py  (new file, +45)

@@ -0,0 +1,45 @@
import logging
from typing import Callable

from flask import Flask, request, jsonify

from image_prediction.config import CONFIG

logger = logging.getLogger(__name__)
logger.setLevel(CONFIG.service.logging_level)


def make_prediction_server(predict_fn: Callable):
    app = Flask(__name__)

    @app.route("/ready", methods=["GET"])
    def ready():
        resp = jsonify("OK")
        resp.status_code = 200
        return resp

    @app.route("/health", methods=["GET"])
    def healthy():
        resp = jsonify("OK")
        resp.status_code = 200
        return resp

    @app.route("/", methods=["POST"])
    def predict():
        pdf = request.data
        logger.debug("Running predictor on document...")
        try:
            predictions = predict_fn(pdf)
            response = jsonify(predictions)
            logger.info("Analysis completed.")
            return response
        except Exception as err:
            logger.error("Analysis failed.")
            logger.exception(err)
            response = jsonify("Analysis failed.")
            response.status_code = 500
            return response

    return app
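
Not part of the diff: a minimal sketch of how this factory might be served, assuming a Predictor class that exposes the new predict_pdf method; the import path, constructor arguments and port are assumptions, not taken from this change.

# Sketch only: Predictor import path, its constructor and the port are assumptions.
from image_prediction.flask import make_prediction_server
from image_prediction.predictor import Predictor  # assumed module path

predictor = Predictor()  # constructor arguments omitted / assumed
app = make_prediction_server(predictor.predict_pdf)  # POST / feeds raw request bytes to predict_fn
app.run(host="0.0.0.0", port=8080)  # development server; port chosen for illustration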

image_prediction/locations.py

@@ -2,12 +2,8 @@ from os import path

 MODULE_DIR = path.dirname(path.abspath(__file__))
 PACKAGE_ROOT_DIR = path.dirname(MODULE_DIR)
-REPO_ROOT_DIR = path.dirname(path.dirname(PACKAGE_ROOT_DIR))
-DOCKER_COMPOSE_FILE = path.join(REPO_ROOT_DIR, "docker-compose.yaml")

 CONFIG_FILE = path.join(PACKAGE_ROOT_DIR, "config.yaml")
-LOG_FILE = "/tmp/log.log"
 DATA_DIR = path.join(PACKAGE_ROOT_DIR, "data")
 MLRUNS_DIR = path.join(DATA_DIR, "mlruns")


@@ -7,6 +7,7 @@ import numpy as np
 from image_prediction.config import CONFIG
 from image_prediction.locations import MLRUNS_DIR, BASE_WEIGHTS
+from image_prediction.utils import temporary_pdf_file
 from incl.redai_image.redai.redai.backend.model.model_handle import ModelHandle
 from incl.redai_image.redai.redai.backend.pdf.image_extraction import extract_and_stitch
 from incl.redai_image.redai.redai.utils.mlflow_reader import MlflowModelReader
@@ -88,29 +89,33 @@ class Predictor:
         return predictions if probabilities else classes

+    def predict_pdf(self, pdf):
+        with temporary_pdf_file(pdf) as pdf_path:
+            image_metadata_pairs = self.__extract_image_metadata_pairs(pdf_path)
+            return self.__predict_images(image_metadata_pairs)

-def extract_image_metadata_pairs(pdf_path: str, **kwargs):
-    def image_is_large_enough(metadata: dict):
-        x1, x2, y1, y2 = itemgetter("x1", "x2", "y1", "y2")(metadata)
+    def __predict_images(self, image_metadata_pairs: Iterable, batch_size: int = CONFIG.service.batch_size):
+        def process_chunk(chunk):
+            images, metadata = zip(*chunk)
+            predictions = self.predict(images, probabilities=True)
+            return predictions, metadata

-        return abs(x1 - x2) > 2 and abs(y1 - y2) > 2
+        def predict(image_metadata_pair_generator):
+            chunks = chunk_iterable(image_metadata_pair_generator, n=batch_size)
+            return map(chain.from_iterable, zip(*map(process_chunk, chunks)))

-    yield from extract_and_stitch(pdf_path, convert_to_rgb=True, filter_fn=image_is_large_enough, **kwargs)
+        try:
+            predictions, metadata = predict(image_metadata_pairs)
+            return predictions, metadata
+        except ValueError:
+            return [], []

-def classify_images(predictor, image_metadata_pairs: Iterable, batch_size: int = CONFIG.service.batch_size):
-    def process_chunk(chunk):
-        images, metadata = zip(*chunk)
-        predictions = predictor.predict(images, probabilities=True)
-        return predictions, metadata
+    @staticmethod
+    def __extract_image_metadata_pairs(pdf_path: str, **kwargs):
+        def image_is_large_enough(metadata: dict):
+            x1, x2, y1, y2 = itemgetter("x1", "x2", "y1", "y2")(metadata)

-    def predict(image_metadata_pair_generator):
-        chunks = chunk_iterable(image_metadata_pair_generator, n=batch_size)
-        return map(chain.from_iterable, zip(*map(process_chunk, chunks)))
+            return abs(x1 - x2) > 2 and abs(y1 - y2) > 2

-    try:
-        predictions, metadata = predict(image_metadata_pairs)
-        return predictions, metadata
-    except ValueError:
-        return [], []
+        yield from extract_and_stitch(pdf_path, convert_to_rgb=True, filter_fn=image_is_large_enough, **kwargs)
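
For orientation, this hunk folds the module-level helpers extract_image_metadata_pairs and classify_images into Predictor as private methods behind a single predict_pdf entry point that takes raw PDF bytes. A hedged sketch of how a call site changes (variable names here are illustrative, not from the diff):

# Before (roughly): callers wired extraction and classification together themselves.
# image_metadata_pairs = extract_image_metadata_pairs(pdf_path)
# predictions, metadata = classify_images(predictor, image_metadata_pairs)

# After: one call on raw bytes; temporary-file handling, extraction and
# batched prediction now live inside the class.
predictions, metadata = predictor.predict_pdf(pdf_bytes)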


@@ -1,11 +1,10 @@
 """Defines functions for constructing service responses."""
+import math
 from itertools import starmap
 from operator import itemgetter

-import numpy as np

 from image_prediction.config import CONFIG
@@ -15,8 +14,8 @@ def build_response(predictions: list, metadata: list) -> list:
 def build_image_info(prediction: dict, metadata: dict) -> dict:
     def compute_geometric_quotient():
-        page_area_sqrt = np.sqrt(abs(page_width * page_height))
-        image_area_sqrt = np.sqrt(abs(x2 - x1) * abs(y2 - y1))
+        page_area_sqrt = math.sqrt(abs(page_width * page_height))
+        image_area_sqrt = math.sqrt(abs(x2 - x1) * abs(y2 - y1))

         return image_area_sqrt / page_area_sqrt

     page_width, page_height, x1, x2, y1, y2, width, height = itemgetter(
@@ -36,7 +35,7 @@ def build_image_info(prediction: dict, metadata: dict) -> dict:
     min_confidence_breached = bool(max(prediction["probabilities"].values()) < CONFIG.filters.min_confidence)

     prediction["label"] = prediction.pop("class")  # "class" as field name causes problem for Java objectmapper
-    prediction["probabilities"] = {klass: np.round(prob, 6) for klass, prob in prediction["probabilities"].items()}
+    prediction["probabilities"] = {klass: round(prob, 6) for klass, prob in prediction["probabilities"].items()}

     image_info = {
         "classification": prediction,

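The response-builder change swaps numpy for the standard library in purely scalar code: for single floats, math.sqrt and the built-in round are drop-in replacements for np.sqrt and np.round and yield plain Python floats. A small illustrative check with invented values (not from the diff):

import math

page_width, page_height = 595.0, 842.0  # example page size, assumption
x1, y1, x2, y2 = 10.0, 10.0, 110.0, 60.0  # example image box, assumption

page_area_sqrt = math.sqrt(abs(page_width * page_height))
image_area_sqrt = math.sqrt(abs(x2 - x1) * abs(y2 - y1))
quotient = image_area_sqrt / page_area_sqrt

print(round(quotient, 6))  # plain float, no numpy scalar type involved
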
image_prediction/utils.py  (new file, +9)

@@ -0,0 +1,9 @@
import tempfile
from contextlib import contextmanager


@contextmanager
def temporary_pdf_file(pdf: bytes):
    with tempfile.NamedTemporaryFile() as f:
        f.write(pdf)
        yield f.name
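
A short usage sketch of the new helper (the input file name and pdf_bytes variable are illustrative); the temporary file only exists inside the with block, which is why predict_pdf performs extraction before leaving it:

from image_prediction.utils import temporary_pdf_file

with open("example.pdf", "rb") as fh:  # illustrative input
    pdf_bytes = fh.read()

with temporary_pdf_file(pdf_bytes) as pdf_path:
    # pdf_path names a NamedTemporaryFile that is removed when the block exits
    print(pdf_path)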