diff --git a/src/serve.py b/src/serve.py index bc6bae2..ba6cb91 100644 --- a/src/serve.py +++ b/src/serve.py @@ -1,57 +1,29 @@ import logging -import tempfile -from flask import Flask, request, jsonify from waitress import serve from image_prediction.config import CONFIG -from image_prediction.predictor import Predictor, extract_image_metadata_pairs, classify_images +from image_prediction.flask import make_prediction_server +from image_prediction.predictor import Predictor from image_prediction.response import build_response +logger = logging.getLogger(__name__) +logger.setLevel(CONFIG.service.logging_level) + def main(): - predictor = Predictor() - logging.info("Predictor ready.") - - app = Flask(__name__) - - @app.route("/ready", methods=["GET"]) - def ready(): - resp = jsonify("OK") - resp.status_code = 200 - return resp - - @app.route("/health", methods=["GET"]) - def healthy(): - resp = jsonify("OK") - resp.status_code = 200 - return resp - - @app.route("/", methods=["POST"]) - def predict(): - pdf = request.data - - logging.debug("Running predictor on document...") - with tempfile.NamedTemporaryFile() as tmp_file: - tmp_file.write(pdf) - image_metadata_pairs = extract_image_metadata_pairs(tmp_file.name) - try: - predictions, metadata = classify_images(predictor, image_metadata_pairs) - except Exception as err: - logging.warning("Analysis failed.") - logging.exception(err) - response = jsonify("Analysis failed.") - response.status_code = 500 - return response - logging.debug(f"Found images in document.") - - response = jsonify(build_response(list(predictions), list(metadata))) - - logging.info("Analysis completed.") + def predict(pdf): + predictions, metadata = predictor.predict_pdf(pdf) + response = build_response(predictions, metadata) return response - run_prediction_server(app, mode=CONFIG.webserver.mode) + predictor = Predictor() + logger.info("Predictor ready.") + + prediction_server = make_prediction_server(predict) + + run_prediction_server(prediction_server, mode=CONFIG.webserver.mode) def run_prediction_server(app, mode="development"): @@ -68,5 +40,7 @@ if __name__ == "__main__": logging.getLogger("urllib3").setLevel(logging.ERROR) logging.getLogger("werkzeug").setLevel(logging.ERROR) logging.getLogger("waitress").setLevel(logging.ERROR) + logging.getLogger("PIL").setLevel(logging.ERROR) + logging.getLogger("h5py").setLevel(logging.ERROR) main()