From 1581880ec6cf1680397ff9a427f76d9ac7e194bf Mon Sep 17 00:00:00 2001 From: Matthias Bisping Date: Thu, 31 Mar 2022 19:38:35 +0200 Subject: [PATCH] added updated version of serve.py --- src/serve.py | 48 ++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 48 insertions(+) create mode 100644 src/serve.py diff --git a/src/serve.py b/src/serve.py new file mode 100644 index 0000000..5001b98 --- /dev/null +++ b/src/serve.py @@ -0,0 +1,48 @@ +import logging + +from waitress import serve + +from image_prediction.config import CONFIG +from image_prediction.default_objects import load_pipeline +from image_prediction.flask import make_prediction_server +from image_prediction.utils import get_logger +from image_prediction.utils.banner import show_banner + +logger = get_logger() + + +def main(): + def predict(pdf): + # Keras service_estimator.predict stalls when service_estimator was loaded in different process; + # therefore, we re-load the model (part of the pipeline) every time we process a new document. + # https://stackoverflow.com/questions/42504669/keras-tensorflow-and-multiprocessing-in-python + logger.debug("Loading pipeline...") + pipeline = load_pipeline(verbose=CONFIG.service.verbose) + logger.debug("Running pipeline...") + return pipeline(pdf) + + prediction_server = make_prediction_server(predict) + + run_prediction_server(prediction_server, mode=CONFIG.webserver.mode) + + +def run_prediction_server(app, mode="development"): + if mode == "development": + app.run(host=CONFIG.webserver.host, port=CONFIG.webserver.port, debug=True) + elif mode == "production": + serve(app, host=CONFIG.webserver.host, port=CONFIG.webserver.port, _quiet=False) + + +if __name__ == "__main__": + logging_level = CONFIG.service.logging_level + logging.basicConfig(level=logging_level) + logging.getLogger("flask").setLevel(logging.ERROR) + logging.getLogger("urllib3").setLevel(logging.ERROR) + logging.getLogger("werkzeug").setLevel(logging.ERROR) + logging.getLogger("waitress").setLevel(logging.ERROR) + logging.getLogger("PIL").setLevel(logging.ERROR) + logging.getLogger("h5py").setLevel(logging.ERROR) + + show_banner() + + main()