diff --git a/image_prediction/flask.py b/image_prediction/flask.py
new file mode 100644
index 0000000..03ed026
--- /dev/null
+++ b/image_prediction/flask.py
@@ -0,0 +1,45 @@
+import logging
+from typing import Callable
+
+from flask import Flask, request, jsonify
+
+from image_prediction.config import CONFIG
+
+logger = logging.getLogger(__name__)
+logger.setLevel(CONFIG.service.logging_level)
+
+
+def make_prediction_server(predict_fn: Callable):
+
+    app = Flask(__name__)
+
+    @app.route("/ready", methods=["GET"])
+    def ready():
+        resp = jsonify("OK")
+        resp.status_code = 200
+        return resp
+
+    @app.route("/health", methods=["GET"])
+    def healthy():
+        resp = jsonify("OK")
+        resp.status_code = 200
+        return resp
+
+    @app.route("/", methods=["POST"])
+    def predict():
+        pdf = request.data
+
+        logger.debug("Running predictor on document...")
+        try:
+            predictions = predict_fn(pdf)
+            response = jsonify(predictions)
+            logger.info("Analysis completed.")
+            return response
+        except Exception as err:
+            logger.error("Analysis failed.")
+            logger.exception(err)
+            response = jsonify("Analysis failed.")
+            response.status_code = 500
+            return response
+
+    return app
diff --git a/image_prediction/locations.py b/image_prediction/locations.py
index 9c67dc0..4bf8d4b 100644
--- a/image_prediction/locations.py
+++ b/image_prediction/locations.py
@@ -2,12 +2,8 @@ from os import path
 
 MODULE_DIR = path.dirname(path.abspath(__file__))
 PACKAGE_ROOT_DIR = path.dirname(MODULE_DIR)
-REPO_ROOT_DIR = path.dirname(path.dirname(PACKAGE_ROOT_DIR))
-
-DOCKER_COMPOSE_FILE = path.join(REPO_ROOT_DIR, "docker-compose.yaml")
 
 CONFIG_FILE = path.join(PACKAGE_ROOT_DIR, "config.yaml")
-LOG_FILE = "/tmp/log.log"
 
 DATA_DIR = path.join(PACKAGE_ROOT_DIR, "data")
 MLRUNS_DIR = path.join(DATA_DIR, "mlruns")
diff --git a/image_prediction/predictor.py b/image_prediction/predictor.py
index 4450e1a..0c2376b 100644
--- a/image_prediction/predictor.py
+++ b/image_prediction/predictor.py
@@ -7,6 +7,7 @@ import numpy as np
 
 from image_prediction.config import CONFIG
 from image_prediction.locations import MLRUNS_DIR, BASE_WEIGHTS
+from image_prediction.utils import temporary_pdf_file
 from incl.redai_image.redai.redai.backend.model.model_handle import ModelHandle
 from incl.redai_image.redai.redai.backend.pdf.image_extraction import extract_and_stitch
 from incl.redai_image.redai.redai.utils.mlflow_reader import MlflowModelReader
@@ -88,29 +89,34 @@ class Predictor:
 
         return predictions if probabilities else classes
 
+    def predict_pdf(self, pdf):
+        with temporary_pdf_file(pdf) as pdf_path:
+            image_metadata_pairs = self.__extract_image_metadata_pairs(pdf_path)
+            return self.__predict_images(image_metadata_pairs)
 
 
-def extract_image_metadata_pairs(pdf_path: str, **kwargs):
-    def image_is_large_enough(metadata: dict):
-        x1, x2, y1, y2 = itemgetter("x1", "x2", "y1", "y2")(metadata)
+    def __predict_images(self, image_metadata_pairs: Iterable, batch_size: int = CONFIG.service.batch_size):
+        def process_chunk(chunk):
+            images, metadata = zip(*chunk)
+            predictions = self.predict(images, probabilities=True)
+            return predictions, metadata
 
-        return abs(x1 - x2) > 2 and abs(y1 - y2) > 2
+        def predict(image_metadata_pair_generator):
+            chunks = chunk_iterable(image_metadata_pair_generator, n=batch_size)
+            return map(chain.from_iterable, zip(*map(process_chunk, chunks)))
 
-    yield from extract_and_stitch(pdf_path, convert_to_rgb=True, filter_fn=image_is_large_enough, **kwargs)
+        try:
+            predictions, metadata = predict(image_metadata_pairs)
+            return predictions, metadata
+        except ValueError:
+            return [], []
 
-def classify_images(predictor, image_metadata_pairs: Iterable, batch_size: int = CONFIG.service.batch_size):
-    def process_chunk(chunk):
-        images, metadata = zip(*chunk)
-        predictions = predictor.predict(images, probabilities=True)
-        return predictions, metadata
+    @staticmethod
+    def __extract_image_metadata_pairs(pdf_path: str, **kwargs):
+        def image_is_large_enough(metadata: dict):
+            x1, x2, y1, y2 = itemgetter("x1", "x2", "y1", "y2")(metadata)
 
-    def predict(image_metadata_pair_generator):
-        chunks = chunk_iterable(image_metadata_pair_generator, n=batch_size)
-        return map(chain.from_iterable, zip(*map(process_chunk, chunks)))
+            return abs(x1 - x2) > 2 and abs(y1 - y2) > 2
 
-    try:
-        predictions, metadata = predict(image_metadata_pairs)
-        return predictions, metadata
+        yield from extract_and_stitch(pdf_path, convert_to_rgb=True, filter_fn=image_is_large_enough, **kwargs)
-    except ValueError:
-        return [], []
 
diff --git a/image_prediction/response.py b/image_prediction/response.py
index 2fc3225..b5cdb7a 100644
--- a/image_prediction/response.py
+++ b/image_prediction/response.py
@@ -1,11 +1,10 @@
 """Defines functions for constructing service responses."""
 
+import math
 from itertools import starmap
 from operator import itemgetter
 
-import numpy as np
-
 from image_prediction.config import CONFIG
 
 
@@ -15,8 +14,8 @@ def build_response(predictions: list, metadata: list) -> list:
 
 def build_image_info(prediction: dict, metadata: dict) -> dict:
     def compute_geometric_quotient():
-        page_area_sqrt = np.sqrt(abs(page_width * page_height))
-        image_area_sqrt = np.sqrt(abs(x2 - x1) * abs(y2 - y1))
+        page_area_sqrt = math.sqrt(abs(page_width * page_height))
+        image_area_sqrt = math.sqrt(abs(x2 - x1) * abs(y2 - y1))
         return image_area_sqrt / page_area_sqrt
 
     page_width, page_height, x1, x2, y1, y2, width, height = itemgetter(
@@ -36,7 +35,7 @@ def build_image_info(prediction: dict, metadata: dict) -> dict:
     min_confidence_breached = bool(max(prediction["probabilities"].values()) < CONFIG.filters.min_confidence)
 
     prediction["label"] = prediction.pop("class")  # "class" as field name causes problem for Java objectmapper
-    prediction["probabilities"] = {klass: np.round(prob, 6) for klass, prob in prediction["probabilities"].items()}
+    prediction["probabilities"] = {klass: round(prob, 6) for klass, prob in prediction["probabilities"].items()}
 
     image_info = {
         "classification": prediction,
diff --git a/image_prediction/utils.py b/image_prediction/utils.py
new file mode 100644
index 0000000..59f56e7
--- /dev/null
+++ b/image_prediction/utils.py
@@ -0,0 +1,10 @@
+import tempfile
+from contextlib import contextmanager
+
+
+@contextmanager
+def temporary_pdf_file(pdf: bytes):
+    with tempfile.NamedTemporaryFile() as f:
+        f.write(pdf)
+        f.flush()  # ensure the full PDF is on disk before the path is read elsewhere
+        yield f.name
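For orientation, the pieces introduced above are meant to be wired together roughly as follows. This is a minimal sketch only, not code from this diff: the `app.py` entry-point name, the no-argument `Predictor()` construction, and the `host`/`port` values are assumptions for illustration.

```python
# app.py (hypothetical entry point, not part of this diff)
from image_prediction.flask import make_prediction_server
from image_prediction.predictor import Predictor
from image_prediction.response import build_response


def main():
    predictor = Predictor()  # assumed no-arg construction; the real signature may differ

    def predict_fn(pdf: bytes):
        # predict_pdf() writes the raw bytes to a temporary file, extracts the
        # page images, and returns (predictions, metadata); build_response()
        # turns that pair into the JSON-serialisable payload the POST / route returns.
        predictions, metadata = predictor.predict_pdf(pdf)
        return build_response(predictions, metadata)

    app = make_prediction_server(predict_fn)
    app.run(host="0.0.0.0", port=8080)  # placeholder host/port


if __name__ == "__main__":
    main()
```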