import multiprocessing
from typing import Callable
from flask import Flask, request, jsonify
from image_prediction.utils import get_logger

logger = get_logger()


def make_prediction_server(predict_fn: Callable):
    """Build a Flask app that serves predictions from predict_fn."""
    app = Flask(__name__)

    @app.route("/ready", methods=["GET"])
    def ready():
        resp = jsonify("OK")
        resp.status_code = 200
        return resp

    @app.route("/health", methods=["GET"])
    def healthy():
        resp = jsonify("OK")
        resp.status_code = 200
        return resp

    @app.route("/", methods=["POST"])
    def predict():
        def predict_fn_wrapper(pdf, return_dict):
            return_dict["result"] = predict_fn(pdf)

        def process():
            # TensorFlow does not free RAM after inference. As a workaround,
            # run service_estimator in a separate process so the memory is
            # released when the process exits. See:
            # https://stackoverflow.com/questions/39758094/clearing-tensorflow-gpu-memory-after-model-execution
            pdf = request.data
            manager = multiprocessing.Manager()
            return_dict = manager.dict()
            # NOTE: targeting a nested function relies on the "fork" start
            # method (the default on Linux); "spawn" cannot pickle it.
            p = multiprocessing.Process(
                target=predict_fn_wrapper,
                args=(pdf, return_dict),
            )
            p.start()
            p.join()
            try:
                return dict(return_dict)["result"]
            except KeyError:
                # The worker exited without writing a result (e.g. it
                # crashed or was OOM-killed); surface a clear error.
                raise RuntimeError("Prediction process exited without a result")

        logger.debug("Running predictor on document...")
        try:
            predictions = process()
            response = jsonify(predictions)
            logger.info("Analysis completed.")
            return response
        except Exception as err:
            logger.error("Analysis failed.")
            logger.exception(err)
            response = jsonify("Analysis failed.")
            response.status_code = 500
            return response

    return app
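

# A minimal usage sketch. `example_predict_fn` and the port are assumptions
# for illustration, not part of this module: predict_fn receives the raw
# POST body bytes and must return a JSON-serializable value.
if __name__ == "__main__":

    def example_predict_fn(pdf: bytes) -> dict:
        # Placeholder predictor; a real one would run the model on `pdf`.
        return {"num_bytes": len(pdf)}

    app = make_prediction_server(example_predict_fn)
    # e.g. test with: curl --data-binary @doc.pdf http://localhost:8080/
    app.run(host="0.0.0.0", port=8080)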