import multiprocessing
import traceback
from typing import Callable

from flask import Flask, request, jsonify
from waitress import serve

from image_prediction.utils import get_logger

logger = get_logger()


def run_prediction_server(app, host, port):
    """Serve *app* with waitress on ``host:port``.

    Args:
        app: The WSGI application to serve (e.g. from ``make_prediction_server``).
        host: Interface to bind to.
        port: Port to listen on.
    """
    serve(app, host=host, port=port, _quiet=False)


def _predict_fn_wrapper(predict_fn, pdf, return_dict):
    """Child-process entry point: run *predict_fn* on *pdf* and report back.

    On success the prediction is stored under ``"result"``. If *predict_fn*
    raises, or its result cannot be sent back through the manager, the
    formatted traceback is stored under ``"error"``. The parent can then
    report the real cause instead of a bare ``KeyError``.

    Defined at module level (not as a closure) so that it stays picklable
    under the "spawn"/"forkserver" start methods.
    """
    try:
        return_dict["result"] = predict_fn(pdf)
    except Exception:
        return_dict["error"] = traceback.format_exc()


def _run_in_subprocess(predict_fn, pdf):
    """Run ``predict_fn(pdf)`` in a fresh child process and return its result.

    TensorFlow does not free RAM after model execution, so each prediction is
    isolated in its own process. See
    https://stackoverflow.com/questions/39758094/clearing-tensorflow-gpu-memory-after-model-execution

    Raises:
        RuntimeError: if the prediction raised in the child, or the child
            exited (e.g. was killed) without producing a result.
    """
    # The context manager shuts the manager's server process down on exit.
    # Without it, one manager process leaks per request.
    with multiprocessing.Manager() as manager:
        return_dict = manager.dict()
        process = multiprocessing.Process(
            target=_predict_fn_wrapper, args=(predict_fn, pdf, return_dict)
        )
        process.start()
        process.join()
        # Copy into a plain dict before the manager (and its proxy) goes away.
        outcome = return_dict.copy()

    if "result" in outcome:
        return outcome["result"]
    if "error" in outcome:
        raise RuntimeError(f"Prediction failed in subprocess:\n{outcome['error']}")
    raise RuntimeError(
        f"Prediction subprocess exited with code {process.exitcode} "
        "without producing a result"
    )


def make_prediction_server(predict_fn: Callable):
    """Build a Flask app exposing readiness, health and prediction endpoints.

    Args:
        predict_fn: Called with the raw request body (``request.data``) of a
            ``POST /predict``. Its return value is serialised with ``jsonify``.

    Returns:
        The configured ``Flask`` application.
    """
    app = Flask(__name__)

    @app.route("/ready", methods=["GET"])
    def ready():
        resp = jsonify("OK")
        resp.status_code = 200
        return resp

    @app.route("/health", methods=["GET"])
    def healthy():
        resp = jsonify("OK")
        resp.status_code = 200
        return resp

    @app.route("/predict", methods=["POST"])
    def predict():
        logger.info("Analysing document...")
        try:
            # Read the body in the request context. The child process has
            # no access to Flask's request object.
            pdf = request.data
            predictions = _run_in_subprocess(predict_fn, pdf)
            response = jsonify(predictions)
            logger.debug("Analysis completed.")
            return response
        except Exception:
            # Logger.exception already attaches the active traceback.
            logger.exception("Analysis failed")
            response = jsonify("Analysis failed")
            response.status_code = 500
            return response

    return app