From 4c9e6c38bdcea7d81008bf9dfcfcdd19d199da6a Mon Sep 17 00:00:00 2001 From: Julius Unverfehrt Date: Mon, 21 Mar 2022 13:53:40 +0100 Subject: [PATCH] add predicting as subprocess, add workaround for keras not working if the model was loaded in different process --- image_prediction/flask.py | 27 +++++++++++++++++++++++++-- src/serve.py | 4 +++- 2 files changed, 28 insertions(+), 3 deletions(-) diff --git a/image_prediction/flask.py b/image_prediction/flask.py index 34f8a29..11aa356 100644 --- a/image_prediction/flask.py +++ b/image_prediction/flask.py @@ -1,3 +1,4 @@ +import multiprocessing from typing import Callable from flask import Flask, request, jsonify @@ -25,11 +26,33 @@ def make_prediction_server(predict_fn: Callable): @app.route("/", methods=["POST"]) def predict(): - pdf = request.data + + def predict_fn_wrapper(pdf, return_dict): + return_dict["result"] = predict_fn(pdf) + + def process(): + # Tensorflow does not free RAM. Workaround is running model in process. + # https://stackoverflow.com/questions/39758094/clearing-tensorflow-gpu-memory-after-model-execution + pdf = request.data + manager = multiprocessing.Manager() + return_dict = manager.dict() + p = multiprocessing.Process( + target=predict_fn_wrapper, + args=( + pdf, + return_dict, + ), + ) + p.start() + p.join() + try: + return dict(return_dict)["result"] + except KeyError: + raise logger.debug("Running predictor on document...") try: - predictions = predict_fn(pdf) + predictions = process() response = jsonify(predictions) logger.info("Analysis completed.") return response diff --git a/src/serve.py b/src/serve.py index f44b632..af7a133 100644 --- a/src/serve.py +++ b/src/serve.py @@ -14,11 +14,13 @@ logger = get_logger() def main(): def predict(pdf): + # Keras model.predict stalls when model was loaded in different process + # https://stackoverflow.com/questions/42504669/keras-tensorflow-and-multiprocessing-in-python + predictor = Predictor() predictions, metadata = predictor.predict_pdf(pdf, verbose=CONFIG.service.progressbar) response = build_response(predictions, metadata) return response - predictor = Predictor() logger.info("Predictor ready.") prediction_server = make_prediction_server(predict)