import multiprocessing
import traceback
from typing import Callable

from flask import Flask, jsonify, request

from image_prediction.utils import get_logger

logger = get_logger()


def run_in_process(func):
    """Run `func` in a child process and block until it finishes."""
    p = multiprocessing.Process(target=func)
    p.start()
    p.join()


def wrap_in_process(func_to_wrap):
    """Wrap `func_to_wrap` so that each call runs in its own process.

    The result is passed back through a manager dict; on failure the
    traceback is logged and None is returned.
    """

    def build_function_and_run_in_process(*args, **kwargs):
        manager = multiprocessing.Manager()
        return_dict = manager.dict()

        def func():
            try:
                return_dict["result"] = func_to_wrap(*args, **kwargs)
            except Exception:
                logger.error(traceback.format_exc())

        run_in_process(func)
        return return_dict.get("result", None)

    return build_function_and_run_in_process


def make_prediction_server(predict_fn: Callable):
    app = Flask(__name__)

    @app.route("/ready", methods=["GET"])
    def ready():
        resp = jsonify("OK")
        resp.status_code = 200
        return resp

    @app.route("/health", methods=["GET"])
    def healthy():
        resp = jsonify("OK")
        resp.status_code = 200
        return resp

    def __failure():
        response = jsonify("Analysis failed")
        response.status_code = 500
        return response

    @app.route("/predict", methods=["POST"])
    @app.route("/", methods=["POST"])
    def predict():
        # TensorFlow does not free RAM. Workaround: run the prediction function
        # (which instantiates a model) in a sub-process so memory is released
        # when the process exits.
        # See: https://stackoverflow.com/questions/39758094/clearing-tensorflow-gpu-memory-after-model-execution
        predict_fn_wrapped = wrap_in_process(predict_fn)

        logger.info("Analysing...")
        predictions = predict_fn_wrapped(request.data)

        if predictions:
            response = jsonify(predictions)
            logger.info("Analysis completed.")
            return response
        else:
            logger.error("Analysis failed.")
            return __failure()

    return app
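
if __name__ == "__main__":
    # Hypothetical usage sketch (not part of the original module): `echo_predict`
    # stands in for a real model-backed prediction function and simply reports
    # the size of the posted payload. Host and port are example values only.
    def echo_predict(data: bytes) -> dict:
        return {"received_bytes": len(data)}

    # Build the Flask app around the prediction function and serve it locally.
    make_prediction_server(echo_predict).run(host="0.0.0.0", port=8080)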