refactoring
This commit is contained in:
parent
1501653673
commit
cb00aed62c
@ -14,6 +14,29 @@ def run_prediction_server(app, host, port):
|
||||
serve(app, host=host, port=port, _quiet=False)
|
||||
|
||||
|
||||
def run_in_process(func):
    """Execute ``func`` in a separate process and block until it has ended.

    ``func`` is handed to the process as its target and is called without
    arguments. Nothing is returned; the target's return value is not
    collected here.
    """
    worker = multiprocessing.Process(target=func)
    worker.start()
    # Wait for the child to finish before giving control back to the caller.
    worker.join()
|
||||
|
||||
|
||||
def wrap_in_process(func_to_wrap):
    """Return a wrapper that runs ``func_to_wrap`` in a child process.

    The wrapper forwards its positional and keyword arguments to
    ``func_to_wrap``, waits for the child process to finish and returns the
    value the call produced. If the call raises, the traceback is logged
    from the child process and the wrapper returns ``None``.

    NOTE(review): the process target is a closure, which the standard pickler
    cannot serialize, so this presumably relies on a start method that does
    not pickle the target (e.g. ``fork``) -- confirm for the deployment
    platform.
    """
    import functools

    @functools.wraps(func_to_wrap)
    def build_function_and_run_in_process(*args, **kwargs):
        # Each Manager() starts its own server process; the context manager
        # shuts it down once the result has been read, instead of leaving one
        # behind per call.
        with multiprocessing.Manager() as manager:
            return_dict = manager.dict()

            def func():
                try:
                    return_dict["result"] = func_to_wrap(*args, **kwargs)
                except Exception:
                    # Only ordinary errors are swallowed; SystemExit and
                    # KeyboardInterrupt still terminate the child normally.
                    logger.error(traceback.format_exc())

            run_in_process(func)
            # Must be read while the manager is still alive; a missing key
            # means the wrapped call failed in the child.
            return return_dict.get("result", None)

    return build_function_and_run_in_process
|
||||
|
||||
|
||||
def make_prediction_server(predict_fn: Callable):
|
||||
app = Flask(__name__)
|
||||
|
||||
@ -29,32 +52,27 @@ def make_prediction_server(predict_fn: Callable):
|
||||
resp.status_code = 200
|
||||
return resp
|
||||
|
||||
def __failure():
|
||||
response = jsonify("Analysis failed")
|
||||
response.status_code = 500
|
||||
return response
|
||||
|
||||
@app.route("/predict", methods=["POST"])
|
||||
def predict():
|
||||
def predict_fn_wrapper(pdf, return_dict):
|
||||
return_dict["result"] = predict_fn(pdf)
|
||||
|
||||
def process():
|
||||
# Tensorflow does not free RAM. Workaround is running service_estimator in process.
|
||||
# https://stackoverflow.com/questions/39758094/clearing-tensorflow-gpu-memory-after-model-execution
|
||||
pdf = request.data
|
||||
manager = multiprocessing.Manager()
|
||||
return_dict = manager.dict()
|
||||
p = multiprocessing.Process(target=predict_fn_wrapper, args=(pdf, return_dict))
|
||||
p.start()
|
||||
p.join()
|
||||
return return_dict["result"]
|
||||
# Tensorflow does not free RAM. Workaround: Run prediction function (which instantiates a model) in sub-process.
|
||||
# See: https://stackoverflow.com/questions/39758094/clearing-tensorflow-gpu-memory-after-model-execution
|
||||
predict_fn_wrapped = wrap_in_process(predict_fn)
|
||||
|
||||
logger.info("Analysing document...")
|
||||
try:
|
||||
predictions = process()
|
||||
logger.info("Analysing...")
|
||||
predictions = predict_fn_wrapped(request.data)
|
||||
|
||||
if predictions:
|
||||
response = jsonify(predictions)
|
||||
logger.debug("Analysis completed.")
|
||||
return response
|
||||
except Exception:
|
||||
logger.exception(f"Analysis failed\n{traceback.format_exc()}")
|
||||
response = jsonify("Analysis failed")
|
||||
response.status_code = 500
|
||||
logger.info("Analysis completed.")
|
||||
return response
|
||||
else:
|
||||
logger.error("Analysis failed.")
|
||||
return __failure()
|
||||
|
||||
return app
|
||||
|
||||
@ -12,11 +12,10 @@ logger.setLevel(logging.CRITICAL + 1)
|
||||
|
||||
def predict_fn(x: bytes):
|
||||
x = int(x.decode())
|
||||
match x:
|
||||
case 42:
|
||||
return True
|
||||
case _:
|
||||
raise Exception
|
||||
if x == 42:
|
||||
return True
|
||||
else:
|
||||
raise Exception("intentional test exception")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user