Compare commits
3 Commits
master
...
refactorin
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
49e113f8d8 | ||
|
|
7ec3d52e15 | ||
|
|
06ea0be8aa |
45
image_prediction/flask.py
Normal file
45
image_prediction/flask.py
Normal file
@ -0,0 +1,45 @@
|
||||
import logging
|
||||
from typing import Callable
|
||||
|
||||
from flask import Flask, request, jsonify
|
||||
|
||||
from image_prediction.config import CONFIG
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
logger.setLevel(CONFIG.service.logging_level)
|
||||
|
||||
|
||||
def make_prediction_server(predict_fn: Callable):
    """Build a Flask application that serves predictions over HTTP.

    Routes registered on the returned app:

    * ``GET /ready`` and ``GET /health`` -- always answer with JSON ``"OK"``
      and status 200.
    * ``POST /`` -- passes the raw request body to ``predict_fn`` and answers
      with its JSON-serialised result; any exception raised while predicting
      or serialising is logged and turned into a 500 response.

    Args:
        predict_fn: Callable invoked with the raw request body
            (``request.data``); its return value is handed to ``jsonify``.

    Returns:
        The configured ``Flask`` application.
    """
    server = Flask(__name__)

    def ok_response():
        # Shared by both probe endpoints: JSON "OK" with an explicit 200.
        reply = jsonify("OK")
        reply.status_code = 200
        return reply

    @server.route("/ready", methods=["GET"])
    def ready():
        return ok_response()

    @server.route("/health", methods=["GET"])
    def healthy():
        return ok_response()

    @server.route("/", methods=["POST"])
    def predict():
        payload = request.data

        logger.debug("Running predictor on document...")
        try:
            # Serialisation stays inside the try block so that a result that
            # cannot be converted to JSON is also reported as a failure.
            success = jsonify(predict_fn(payload))
            logger.info("Analysis completed.")
            return success
        except Exception as err:
            logger.error("Analysis failed.")
            logger.exception(err)
            failure = jsonify("Analysis failed.")
            failure.status_code = 500
            return failure

    return server
|
||||
@ -2,12 +2,8 @@ from os import path
|
||||
|
||||
MODULE_DIR = path.dirname(path.abspath(__file__))
|
||||
PACKAGE_ROOT_DIR = path.dirname(MODULE_DIR)
|
||||
REPO_ROOT_DIR = path.dirname(path.dirname(PACKAGE_ROOT_DIR))
|
||||
|
||||
DOCKER_COMPOSE_FILE = path.join(REPO_ROOT_DIR, "docker-compose.yaml")
|
||||
|
||||
CONFIG_FILE = path.join(PACKAGE_ROOT_DIR, "config.yaml")
|
||||
LOG_FILE = "/tmp/log.log"
|
||||
|
||||
DATA_DIR = path.join(PACKAGE_ROOT_DIR, "data")
|
||||
MLRUNS_DIR = path.join(DATA_DIR, "mlruns")
|
||||
|
||||
@ -7,6 +7,7 @@ import numpy as np
|
||||
|
||||
from image_prediction.config import CONFIG
|
||||
from image_prediction.locations import MLRUNS_DIR, BASE_WEIGHTS
|
||||
from image_prediction.utils import temporary_pdf_file
|
||||
from incl.redai_image.redai.redai.backend.model.model_handle import ModelHandle
|
||||
from incl.redai_image.redai.redai.backend.pdf.image_extraction import extract_and_stitch
|
||||
from incl.redai_image.redai.redai.utils.mlflow_reader import MlflowModelReader
|
||||
@ -88,29 +89,33 @@ class Predictor:
|
||||
|
||||
return predictions if probabilities else classes
|
||||
|
||||
def predict_pdf(self, pdf):
|
||||
with temporary_pdf_file(pdf) as pdf_path:
|
||||
image_metadata_pairs = self.__extract_image_metadata_pairs(pdf_path)
|
||||
return self.__predict_images(image_metadata_pairs)
|
||||
|
||||
def extract_image_metadata_pairs(pdf_path: str, **kwargs):
|
||||
def image_is_large_enough(metadata: dict):
|
||||
x1, x2, y1, y2 = itemgetter("x1", "x2", "y1", "y2")(metadata)
|
||||
def __predict_images(self, image_metadata_pairs: Iterable, batch_size: int = CONFIG.service.batch_size):
|
||||
def process_chunk(chunk):
|
||||
images, metadata = zip(*chunk)
|
||||
predictions = self.predict(images, probabilities=True)
|
||||
return predictions, metadata
|
||||
|
||||
return abs(x1 - x2) > 2 and abs(y1 - y2) > 2
|
||||
def predict(image_metadata_pair_generator):
|
||||
chunks = chunk_iterable(image_metadata_pair_generator, n=batch_size)
|
||||
return map(chain.from_iterable, zip(*map(process_chunk, chunks)))
|
||||
|
||||
yield from extract_and_stitch(pdf_path, convert_to_rgb=True, filter_fn=image_is_large_enough, **kwargs)
|
||||
try:
|
||||
predictions, metadata = predict(image_metadata_pairs)
|
||||
return predictions, metadata
|
||||
|
||||
except ValueError:
|
||||
return [], []
|
||||
|
||||
def classify_images(predictor, image_metadata_pairs: Iterable, batch_size: int = CONFIG.service.batch_size):
|
||||
def process_chunk(chunk):
|
||||
images, metadata = zip(*chunk)
|
||||
predictions = predictor.predict(images, probabilities=True)
|
||||
return predictions, metadata
|
||||
@staticmethod
|
||||
def __extract_image_metadata_pairs(pdf_path: str, **kwargs):
|
||||
def image_is_large_enough(metadata: dict):
|
||||
x1, x2, y1, y2 = itemgetter("x1", "x2", "y1", "y2")(metadata)
|
||||
|
||||
def predict(image_metadata_pair_generator):
|
||||
chunks = chunk_iterable(image_metadata_pair_generator, n=batch_size)
|
||||
return map(chain.from_iterable, zip(*map(process_chunk, chunks)))
|
||||
return abs(x1 - x2) > 2 and abs(y1 - y2) > 2
|
||||
|
||||
try:
|
||||
predictions, metadata = predict(image_metadata_pairs)
|
||||
return predictions, metadata
|
||||
|
||||
except ValueError:
|
||||
return [], []
|
||||
yield from extract_and_stitch(pdf_path, convert_to_rgb=True, filter_fn=image_is_large_enough, **kwargs)
|
||||
|
||||
@ -1,11 +1,10 @@
|
||||
"""Defines functions for constructing service responses."""
|
||||
|
||||
|
||||
import math
|
||||
from itertools import starmap
|
||||
from operator import itemgetter
|
||||
|
||||
import numpy as np
|
||||
|
||||
from image_prediction.config import CONFIG
|
||||
|
||||
|
||||
@ -15,8 +14,8 @@ def build_response(predictions: list, metadata: list) -> list:
|
||||
|
||||
def build_image_info(prediction: dict, metadata: dict) -> dict:
|
||||
def compute_geometric_quotient():
|
||||
page_area_sqrt = np.sqrt(abs(page_width * page_height))
|
||||
image_area_sqrt = np.sqrt(abs(x2 - x1) * abs(y2 - y1))
|
||||
page_area_sqrt = math.sqrt(abs(page_width * page_height))
|
||||
image_area_sqrt = math.sqrt(abs(x2 - x1) * abs(y2 - y1))
|
||||
return image_area_sqrt / page_area_sqrt
|
||||
|
||||
page_width, page_height, x1, x2, y1, y2, width, height = itemgetter(
|
||||
@ -36,7 +35,7 @@ def build_image_info(prediction: dict, metadata: dict) -> dict:
|
||||
|
||||
min_confidence_breached = bool(max(prediction["probabilities"].values()) < CONFIG.filters.min_confidence)
|
||||
prediction["label"] = prediction.pop("class") # "class" as field name causes problem for Java objectmapper
|
||||
prediction["probabilities"] = {klass: np.round(prob, 6) for klass, prob in prediction["probabilities"].items()}
|
||||
prediction["probabilities"] = {klass: round(prob, 6) for klass, prob in prediction["probabilities"].items()}
|
||||
|
||||
image_info = {
|
||||
"classification": prediction,
|
||||
|
||||
9
image_prediction/utils.py
Normal file
9
image_prediction/utils.py
Normal file
@ -0,0 +1,9 @@
|
||||
import tempfile
|
||||
from contextlib import contextmanager
|
||||
|
||||
|
||||
@contextmanager
def temporary_pdf_file(pdf: bytes):
    """Write ``pdf`` to a named temporary file and yield that file's path.

    The file is created with ``tempfile.NamedTemporaryFile`` and is removed
    again when the ``with`` block that uses this context manager exits.

    Args:
        pdf: Raw bytes to store in the temporary file.

    Yields:
        The filesystem path (``str``) of the temporary file.
    """
    with tempfile.NamedTemporaryFile() as f:
        f.write(pdf)
        # Only the path is handed out, so the content has to be on disk
        # before yielding; without the flush, payloads smaller than the
        # write buffer would be seen as an empty file by whoever opens it.
        f.flush()
        yield f.name
|
||||
58
src/serve.py
58
src/serve.py
@ -1,57 +1,29 @@
|
||||
import logging
|
||||
import tempfile
|
||||
|
||||
from flask import Flask, request, jsonify
|
||||
from waitress import serve
|
||||
|
||||
from image_prediction.config import CONFIG
|
||||
from image_prediction.predictor import Predictor, extract_image_metadata_pairs, classify_images
|
||||
from image_prediction.flask import make_prediction_server
|
||||
from image_prediction.predictor import Predictor
|
||||
from image_prediction.response import build_response
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
logger.setLevel(CONFIG.service.logging_level)
|
||||
|
||||
|
||||
def main():
|
||||
|
||||
predictor = Predictor()
|
||||
logging.info("Predictor ready.")
|
||||
|
||||
app = Flask(__name__)
|
||||
|
||||
@app.route("/ready", methods=["GET"])
|
||||
def ready():
|
||||
resp = jsonify("OK")
|
||||
resp.status_code = 200
|
||||
return resp
|
||||
|
||||
@app.route("/health", methods=["GET"])
|
||||
def healthy():
|
||||
resp = jsonify("OK")
|
||||
resp.status_code = 200
|
||||
return resp
|
||||
|
||||
@app.route("/", methods=["POST"])
|
||||
def predict():
|
||||
pdf = request.data
|
||||
|
||||
logging.debug("Running predictor on document...")
|
||||
with tempfile.NamedTemporaryFile() as tmp_file:
|
||||
tmp_file.write(pdf)
|
||||
image_metadata_pairs = extract_image_metadata_pairs(tmp_file.name)
|
||||
try:
|
||||
predictions, metadata = classify_images(predictor, image_metadata_pairs)
|
||||
except Exception as err:
|
||||
logging.warning("Analysis failed.")
|
||||
logging.exception(err)
|
||||
response = jsonify("Analysis failed.")
|
||||
response.status_code = 500
|
||||
return response
|
||||
logging.debug(f"Found images in document.")
|
||||
|
||||
response = jsonify(build_response(list(predictions), list(metadata)))
|
||||
|
||||
logging.info("Analysis completed.")
|
||||
def predict(pdf):
|
||||
predictions, metadata = predictor.predict_pdf(pdf)
|
||||
response = build_response(predictions, metadata)
|
||||
return response
|
||||
|
||||
run_prediction_server(app, mode=CONFIG.webserver.mode)
|
||||
predictor = Predictor()
|
||||
logger.info("Predictor ready.")
|
||||
|
||||
prediction_server = make_prediction_server(predict)
|
||||
|
||||
run_prediction_server(prediction_server, mode=CONFIG.webserver.mode)
|
||||
|
||||
|
||||
def run_prediction_server(app, mode="development"):
|
||||
@ -68,5 +40,7 @@ if __name__ == "__main__":
|
||||
logging.getLogger("urllib3").setLevel(logging.ERROR)
|
||||
logging.getLogger("werkzeug").setLevel(logging.ERROR)
|
||||
logging.getLogger("waitress").setLevel(logging.ERROR)
|
||||
logging.getLogger("PIL").setLevel(logging.ERROR)
|
||||
logging.getLogger("h5py").setLevel(logging.ERROR)
|
||||
|
||||
main()
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user