From a9d60654f55765b4faeb8901d8db8ac7ecf6fb33 Mon Sep 17 00:00:00 2001 From: Matthias Bisping Date: Wed, 16 Mar 2022 15:07:30 +0100 Subject: [PATCH] Pull request #3: Refactoring Merge in RR/image-prediction from refactoring to master Squashed commit of the following: commit fc4e2efac113f2e307fdbc091e0a4f4e3e5729d3 Author: Matthias Bisping Date: Wed Mar 16 14:21:05 2022 +0100 applied black commit 3baabf5bc0b04347af85dafbb056f134258d9715 Author: Matthias Bisping Date: Wed Mar 16 14:20:30 2022 +0100 added banner commit 30e871cfdc79d0ff2e0c26d1b858e55ab1b0453f Author: Matthias Bisping Date: Wed Mar 16 14:02:26 2022 +0100 rename logger commit d76fefd3ff0c4425defca4db218ce4a84c6053f3 Author: Matthias Bisping Date: Wed Mar 16 14:00:39 2022 +0100 logger refactoring commit 0e004cbd21ab00b8804901952405fa870bf48e9c Author: Matthias Bisping Date: Wed Mar 16 14:00:08 2022 +0100 logger refactoring commit 49e113f8d85d7973b73f664779906a1347d1522d Author: Matthias Bisping Date: Wed Mar 16 13:25:08 2022 +0100 refactoring commit 7ec3d52e155cb83bed8804d2fee4f5bdf54fb59b Author: Matthias Bisping Date: Wed Mar 16 13:21:52 2022 +0100 applied black commit 06ea0be8aa9344e11b9d92fd526f2b73061bc736 Author: Matthias Bisping Date: Wed Mar 16 13:21:20 2022 +0100 refactoring --- config.yaml | 2 +- image_prediction/flask.py | 43 ++++++++++++++++++++++ image_prediction/locations.py | 4 --- image_prediction/predictor.py | 48 ++++++++++++++----------- image_prediction/response.py | 9 +++-- image_prediction/utils.py | 68 +++++++++++++++++++++++++++++++++++ src/serve.py | 60 ++++++++++--------------------- 7 files changed, 161 insertions(+), 73 deletions(-) create mode 100644 image_prediction/flask.py create mode 100644 image_prediction/utils.py diff --git a/config.yaml b/config.yaml index 77c5141..dbc5b87 100644 --- a/config.yaml +++ b/config.yaml @@ -5,6 +5,7 @@ webserver: service: logging_level: $LOGGING_LEVEL_ROOT|DEBUG # Logging level for service logger + progressbar: True # Whether a progress bar over the pages of a document is displayed while processing batch_size: $BATCH_SIZE|32 # Number of images in memory simultaneously verbose: $VERBOSE|True # Service prints document processing progress to stdout run_id: $RUN_ID|fabfb1f192c745369b88cab34471aba7 # The ID of the mlflow run to load the model from @@ -25,4 +26,3 @@ filters: max: $MAX_IMAGE_FORMAT|10 # Maximum permissible min_confidence: $MIN_CONFIDENCE|0.5 # Minimum permissible prediction confidence - diff --git a/image_prediction/flask.py b/image_prediction/flask.py new file mode 100644 index 0000000..34f8a29 --- /dev/null +++ b/image_prediction/flask.py @@ -0,0 +1,43 @@ +from typing import Callable + +from flask import Flask, request, jsonify + +from image_prediction.utils import get_logger + +logger = get_logger() + + +def make_prediction_server(predict_fn: Callable): + + app = Flask(__name__) + + @app.route("/ready", methods=["GET"]) + def ready(): + resp = jsonify("OK") + resp.status_code = 200 + return resp + + @app.route("/health", methods=["GET"]) + def healthy(): + resp = jsonify("OK") + resp.status_code = 200 + return resp + + @app.route("/", methods=["POST"]) + def predict(): + pdf = request.data + + logger.debug("Running predictor on document...") + try: + predictions = predict_fn(pdf) + response = jsonify(predictions) + logger.info("Analysis completed.") + return response + except Exception as err: + logger.error("Analysis failed.") + logger.exception(err) + response = jsonify("Analysis failed.") + response.status_code = 500 + return response + + return app diff --git a/image_prediction/locations.py b/image_prediction/locations.py index 9c67dc0..4bf8d4b 100644 --- a/image_prediction/locations.py +++ b/image_prediction/locations.py @@ -2,12 +2,8 @@ from os import path MODULE_DIR = path.dirname(path.abspath(__file__)) PACKAGE_ROOT_DIR = path.dirname(MODULE_DIR) -REPO_ROOT_DIR = path.dirname(path.dirname(PACKAGE_ROOT_DIR)) - -DOCKER_COMPOSE_FILE = path.join(REPO_ROOT_DIR, "docker-compose.yaml") CONFIG_FILE = path.join(PACKAGE_ROOT_DIR, "config.yaml") -LOG_FILE = "/tmp/log.log" DATA_DIR = path.join(PACKAGE_ROOT_DIR, "data") MLRUNS_DIR = path.join(DATA_DIR, "mlruns") diff --git a/image_prediction/predictor.py b/image_prediction/predictor.py index 4450e1a..8e83f2c 100644 --- a/image_prediction/predictor.py +++ b/image_prediction/predictor.py @@ -1,4 +1,3 @@ -import logging from itertools import chain from operator import itemgetter from typing import List, Dict, Iterable @@ -7,11 +6,14 @@ import numpy as np from image_prediction.config import CONFIG from image_prediction.locations import MLRUNS_DIR, BASE_WEIGHTS +from image_prediction.utils import temporary_pdf_file, get_logger from incl.redai_image.redai.redai.backend.model.model_handle import ModelHandle from incl.redai_image.redai.redai.backend.pdf.image_extraction import extract_and_stitch from incl.redai_image.redai.redai.utils.mlflow_reader import MlflowModelReader from incl.redai_image.redai.redai.utils.shared import chunk_iterable +logger = get_logger() + class Predictor: """`ModelHandle` wrapper. Forwards to wrapped model handle for prediction and produces structured output that is @@ -36,7 +38,7 @@ class Predictor: self.classes_readable = np.array(self.model_handle.classes) self.classes_readable_aligned = self.classes_readable[self.classes[list(range(len(self.classes)))]] except Exception as e: - logging.info(f"Service estimator initialization failed: {e}") + logger.info(f"Service estimator initialization failed: {e}") def __make_predictions_human_readable(self, probs: np.ndarray) -> List[Dict[str, float]]: """Translates an n x m matrix of probabilities over classes into an n-element list of mappings from classes to @@ -88,29 +90,33 @@ class Predictor: return predictions if probabilities else classes + def predict_pdf(self, pdf, verbose=False): + with temporary_pdf_file(pdf) as pdf_path: + image_metadata_pairs = self.__extract_image_metadata_pairs(pdf_path, verbose=verbose) + return self.__predict_images(image_metadata_pairs) -def extract_image_metadata_pairs(pdf_path: str, **kwargs): - def image_is_large_enough(metadata: dict): - x1, x2, y1, y2 = itemgetter("x1", "x2", "y1", "y2")(metadata) + def __predict_images(self, image_metadata_pairs: Iterable, batch_size: int = CONFIG.service.batch_size): + def process_chunk(chunk): + images, metadata = zip(*chunk) + predictions = self.predict(images, probabilities=True) + return predictions, metadata - return abs(x1 - x2) > 2 and abs(y1 - y2) > 2 + def predict(image_metadata_pair_generator): + chunks = chunk_iterable(image_metadata_pair_generator, n=batch_size) + return map(chain.from_iterable, zip(*map(process_chunk, chunks))) - yield from extract_and_stitch(pdf_path, convert_to_rgb=True, filter_fn=image_is_large_enough, **kwargs) + try: + predictions, metadata = predict(image_metadata_pairs) + return predictions, metadata + except ValueError: + return [], [] -def classify_images(predictor, image_metadata_pairs: Iterable, batch_size: int = CONFIG.service.batch_size): - def process_chunk(chunk): - images, metadata = zip(*chunk) - predictions = predictor.predict(images, probabilities=True) - return predictions, metadata + @staticmethod + def __extract_image_metadata_pairs(pdf_path: str, **kwargs): + def image_is_large_enough(metadata: dict): + x1, x2, y1, y2 = itemgetter("x1", "x2", "y1", "y2")(metadata) - def predict(image_metadata_pair_generator): - chunks = chunk_iterable(image_metadata_pair_generator, n=batch_size) - return map(chain.from_iterable, zip(*map(process_chunk, chunks))) + return abs(x1 - x2) > 2 and abs(y1 - y2) > 2 - try: - predictions, metadata = predict(image_metadata_pairs) - return predictions, metadata - - except ValueError: - return [], [] + yield from extract_and_stitch(pdf_path, convert_to_rgb=True, filter_fn=image_is_large_enough, **kwargs) diff --git a/image_prediction/response.py b/image_prediction/response.py index 2fc3225..b5cdb7a 100644 --- a/image_prediction/response.py +++ b/image_prediction/response.py @@ -1,11 +1,10 @@ """Defines functions for constructing service responses.""" +import math from itertools import starmap from operator import itemgetter -import numpy as np - from image_prediction.config import CONFIG @@ -15,8 +14,8 @@ def build_response(predictions: list, metadata: list) -> list: def build_image_info(prediction: dict, metadata: dict) -> dict: def compute_geometric_quotient(): - page_area_sqrt = np.sqrt(abs(page_width * page_height)) - image_area_sqrt = np.sqrt(abs(x2 - x1) * abs(y2 - y1)) + page_area_sqrt = math.sqrt(abs(page_width * page_height)) + image_area_sqrt = math.sqrt(abs(x2 - x1) * abs(y2 - y1)) return image_area_sqrt / page_area_sqrt page_width, page_height, x1, x2, y1, y2, width, height = itemgetter( @@ -36,7 +35,7 @@ def build_image_info(prediction: dict, metadata: dict) -> dict: min_confidence_breached = bool(max(prediction["probabilities"].values()) < CONFIG.filters.min_confidence) prediction["label"] = prediction.pop("class") # "class" as field name causes problem for Java objectmapper - prediction["probabilities"] = {klass: np.round(prob, 6) for klass, prob in prediction["probabilities"].items()} + prediction["probabilities"] = {klass: round(prob, 6) for klass, prob in prediction["probabilities"].items()} image_info = { "classification": prediction, diff --git a/image_prediction/utils.py b/image_prediction/utils.py new file mode 100644 index 0000000..15badca --- /dev/null +++ b/image_prediction/utils.py @@ -0,0 +1,68 @@ +import logging +import tempfile +from contextlib import contextmanager + +from image_prediction.config import CONFIG + + +@contextmanager +def temporary_pdf_file(pdf: bytes): + with tempfile.NamedTemporaryFile() as f: + f.write(pdf) + yield f.name + + +def make_logger_getter(): + + logger = logging.getLogger("imclf") + logger.propagate = False + + handler = logging.StreamHandler() + handler.setLevel(CONFIG.service.logging_level) + + log_format = "[%(levelname)s]: %(message)s" + formatter = logging.Formatter(log_format) + + handler.setFormatter(formatter) + logger.addHandler(handler) + + def get_logger(): + return logger + + return get_logger + + +get_logger = make_logger_getter() + + +def show_banner(): + banner = ''' + ..... . ... .. + .d88888Neu. 'L xH88"`~ .x8X x .d88" oec : + F""""*8888888F .. . : :8888 .f"8888Hf 5888R @88888 + * `"*88*" .888: x888 x888. :8888> X8L ^""` '888R 8"*88% + -.... ue=:. ~`8888~'888X`?888f` X8888 X888h 888R 8b. + :88N ` X888 888X '888> 88888 !88888. 888R u888888> + 9888L X888 888X '888> 88888 %88888 888R 8888R + uzu. `8888L X888 888X '888> 88888 '> `8888> 888R 8888P +,""888i ?8888 X888 888X '888> `8888L % ?888 ! 888R *888> +4 9888L %888> "*88%""*88" '888!` `8888 `-*"" / .888B . 4888 +' '8888 '88% `~ " `"` "888. :" ^*888% '888 + "*8Nu.z*" `""***~"` "% 88R + 88> + 48 + '8 + ''' + + logger = logging.getLogger(__name__) + logger.propagate = False + + handler = logging.StreamHandler() + handler.setLevel(logging.INFO) + + formatter = logging.Formatter("") + + handler.setFormatter(formatter) + logger.addHandler(handler) + + logger.info(banner) diff --git a/src/serve.py b/src/serve.py index bc6bae2..f44b632 100644 --- a/src/serve.py +++ b/src/serve.py @@ -1,57 +1,29 @@ import logging -import tempfile -from flask import Flask, request, jsonify from waitress import serve from image_prediction.config import CONFIG -from image_prediction.predictor import Predictor, extract_image_metadata_pairs, classify_images +from image_prediction.flask import make_prediction_server +from image_prediction.predictor import Predictor from image_prediction.response import build_response +from image_prediction.utils import get_logger, show_banner + +logger = get_logger() def main(): - predictor = Predictor() - logging.info("Predictor ready.") - - app = Flask(__name__) - - @app.route("/ready", methods=["GET"]) - def ready(): - resp = jsonify("OK") - resp.status_code = 200 - return resp - - @app.route("/health", methods=["GET"]) - def healthy(): - resp = jsonify("OK") - resp.status_code = 200 - return resp - - @app.route("/", methods=["POST"]) - def predict(): - pdf = request.data - - logging.debug("Running predictor on document...") - with tempfile.NamedTemporaryFile() as tmp_file: - tmp_file.write(pdf) - image_metadata_pairs = extract_image_metadata_pairs(tmp_file.name) - try: - predictions, metadata = classify_images(predictor, image_metadata_pairs) - except Exception as err: - logging.warning("Analysis failed.") - logging.exception(err) - response = jsonify("Analysis failed.") - response.status_code = 500 - return response - logging.debug(f"Found images in document.") - - response = jsonify(build_response(list(predictions), list(metadata))) - - logging.info("Analysis completed.") + def predict(pdf): + predictions, metadata = predictor.predict_pdf(pdf, verbose=CONFIG.service.progressbar) + response = build_response(predictions, metadata) return response - run_prediction_server(app, mode=CONFIG.webserver.mode) + predictor = Predictor() + logger.info("Predictor ready.") + + prediction_server = make_prediction_server(predict) + + run_prediction_server(prediction_server, mode=CONFIG.webserver.mode) def run_prediction_server(app, mode="development"): @@ -68,5 +40,9 @@ if __name__ == "__main__": logging.getLogger("urllib3").setLevel(logging.ERROR) logging.getLogger("werkzeug").setLevel(logging.ERROR) logging.getLogger("waitress").setLevel(logging.ERROR) + logging.getLogger("PIL").setLevel(logging.ERROR) + logging.getLogger("h5py").setLevel(logging.ERROR) + + show_banner() main()