add predicting as subprocess, add workaround for keras not working if the model was loaded in different process

This commit is contained in:
Julius Unverfehrt 2022-03-21 13:53:40 +01:00
parent 530de2ff89
commit 4c9e6c38bd
2 changed files with 28 additions and 3 deletions

View File

@ -1,3 +1,4 @@
import multiprocessing
from typing import Callable from typing import Callable
from flask import Flask, request, jsonify from flask import Flask, request, jsonify
@ -25,11 +26,33 @@ def make_prediction_server(predict_fn: Callable):
@app.route("/", methods=["POST"]) @app.route("/", methods=["POST"])
def predict(): def predict():
pdf = request.data
def predict_fn_wrapper(pdf, return_dict):
return_dict["result"] = predict_fn(pdf)
def process():
# Tensorflow does not free RAM. Workaround is running model in process.
# https://stackoverflow.com/questions/39758094/clearing-tensorflow-gpu-memory-after-model-execution
pdf = request.data
manager = multiprocessing.Manager()
return_dict = manager.dict()
p = multiprocessing.Process(
target=predict_fn_wrapper,
args=(
pdf,
return_dict,
),
)
p.start()
p.join()
try:
return dict(return_dict)["result"]
except KeyError:
raise
logger.debug("Running predictor on document...") logger.debug("Running predictor on document...")
try: try:
predictions = predict_fn(pdf) predictions = process()
response = jsonify(predictions) response = jsonify(predictions)
logger.info("Analysis completed.") logger.info("Analysis completed.")
return response return response

View File

@ -14,11 +14,13 @@ logger = get_logger()
def main(): def main():
def predict(pdf): def predict(pdf):
# Keras model.predict stalls when model was loaded in different process
# https://stackoverflow.com/questions/42504669/keras-tensorflow-and-multiprocessing-in-python
predictor = Predictor()
predictions, metadata = predictor.predict_pdf(pdf, verbose=CONFIG.service.progressbar) predictions, metadata = predictor.predict_pdf(pdf, verbose=CONFIG.service.progressbar)
response = build_response(predictions, metadata) response = build_response(predictions, metadata)
return response return response
predictor = Predictor()
logger.info("Predictor ready.") logger.info("Predictor ready.")
prediction_server = make_prediction_server(predict) prediction_server = make_prediction_server(predict)