refactoring
This commit is contained in:
parent
1501653673
commit
cb00aed62c
@ -14,6 +14,29 @@ def run_prediction_server(app, host, port):
|
|||||||
serve(app, host=host, port=port, _quiet=False)
|
serve(app, host=host, port=port, _quiet=False)
|
||||||
|
|
||||||
|
|
||||||
|
def run_in_process(func):
    """Execute ``func`` in a separate process and wait for it to finish.

    ``func`` is invoked without arguments as the target of a
    ``multiprocessing.Process``. The call blocks until the child process
    has exited. Nothing is returned; the child's exit code is not inspected.
    """
    worker = multiprocessing.Process(target=func)
    worker.start()
    # Block the caller until the child process has terminated.
    worker.join()
|
||||||
|
|
||||||
|
|
||||||
|
def wrap_in_process(func_to_wrap):
    """Return a wrapper that runs ``func_to_wrap`` in a child process.

    The wrapper accepts the same positional and keyword arguments as
    ``func_to_wrap``. Each call starts a ``multiprocessing.Manager``, runs
    ``func_to_wrap(*args, **kwargs)`` in a child process via
    ``run_in_process`` and hands the result back through a managed dict.

    Returns:
        A callable that returns the wrapped function's result, or ``None``
        if the wrapped function raised (the traceback is logged in the
        child) or stored no result.
    """
    # Local import: the file's import block is not visible from this block.
    import functools

    @functools.wraps(func_to_wrap)
    def build_function_and_run_in_process(*args, **kwargs):
        def func():
            # Runs inside the child process. NOTE: this is a closure, so it
            # can only serve as a Process target when the start method does
            # not need to pickle it (i.e. "fork").
            try:
                return_dict["result"] = func_to_wrap(*args, **kwargs)
            except Exception:
                # Catch Exception rather than everything, so SystemExit and
                # KeyboardInterrupt are not silently swallowed in the child.
                logger.error(traceback.format_exc())

        # The context manager shuts the manager's server process down
        # deterministically instead of relying on garbage collection.
        with multiprocessing.Manager() as manager:
            return_dict = manager.dict()
            run_in_process(func)
            # Read the result while the manager is still alive; the proxy
            # is unusable after shutdown.
            return return_dict.get("result", None)

    return build_function_and_run_in_process
|
||||||
|
|
||||||
|
|
||||||
def make_prediction_server(predict_fn: Callable):
|
def make_prediction_server(predict_fn: Callable):
|
||||||
app = Flask(__name__)
|
app = Flask(__name__)
|
||||||
|
|
||||||
@ -29,32 +52,27 @@ def make_prediction_server(predict_fn: Callable):
|
|||||||
resp.status_code = 200
|
resp.status_code = 200
|
||||||
return resp
|
return resp
|
||||||
|
|
||||||
|
def __failure():
|
||||||
|
response = jsonify("Analysis failed")
|
||||||
|
response.status_code = 500
|
||||||
|
return response
|
||||||
|
|
||||||
@app.route("/predict", methods=["POST"])
|
@app.route("/predict", methods=["POST"])
|
||||||
def predict():
|
def predict():
|
||||||
def predict_fn_wrapper(pdf, return_dict):
|
|
||||||
return_dict["result"] = predict_fn(pdf)
|
|
||||||
|
|
||||||
def process():
|
# Tensorflow does not free RAM. Workaround: Run prediction function (which instantiates a model) in sub-process.
|
||||||
# Tensorflow does not free RAM. Workaround is running service_estimator in process.
|
# See: https://stackoverflow.com/questions/39758094/clearing-tensorflow-gpu-memory-after-model-execution
|
||||||
# https://stackoverflow.com/questions/39758094/clearing-tensorflow-gpu-memory-after-model-execution
|
predict_fn_wrapped = wrap_in_process(predict_fn)
|
||||||
pdf = request.data
|
|
||||||
manager = multiprocessing.Manager()
|
|
||||||
return_dict = manager.dict()
|
|
||||||
p = multiprocessing.Process(target=predict_fn_wrapper, args=(pdf, return_dict))
|
|
||||||
p.start()
|
|
||||||
p.join()
|
|
||||||
return return_dict["result"]
|
|
||||||
|
|
||||||
logger.info("Analysing document...")
|
logger.info("Analysing...")
|
||||||
try:
|
predictions = predict_fn_wrapped(request.data)
|
||||||
predictions = process()
|
|
||||||
|
if predictions:
|
||||||
response = jsonify(predictions)
|
response = jsonify(predictions)
|
||||||
logger.debug("Analysis completed.")
|
logger.info("Analysis completed.")
|
||||||
return response
|
|
||||||
except Exception:
|
|
||||||
logger.exception(f"Analysis failed\n{traceback.format_exc()}")
|
|
||||||
response = jsonify("Analysis failed")
|
|
||||||
response.status_code = 500
|
|
||||||
return response
|
return response
|
||||||
|
else:
|
||||||
|
logger.error("Analysis failed.")
|
||||||
|
return __failure()
|
||||||
|
|
||||||
return app
|
return app
|
||||||
|
|||||||
@ -12,11 +12,10 @@ logger.setLevel(logging.CRITICAL + 1)
|
|||||||
|
|
||||||
def predict_fn(x: bytes):
|
def predict_fn(x: bytes):
|
||||||
x = int(x.decode())
|
x = int(x.decode())
|
||||||
match x:
|
if x == 42:
|
||||||
case 42:
|
return True
|
||||||
return True
|
else:
|
||||||
case _:
|
raise Exception("intentional test exception")
|
||||||
raise Exception
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user