refactoring
This commit is contained in:
parent
7ec3d52e15
commit
49e113f8d8
58
src/serve.py
58
src/serve.py
@ -1,57 +1,29 @@
|
||||
import logging
|
||||
import tempfile
|
||||
|
||||
from flask import Flask, request, jsonify
|
||||
from waitress import serve
|
||||
|
||||
from image_prediction.config import CONFIG
|
||||
from image_prediction.predictor import Predictor, extract_image_metadata_pairs, classify_images
|
||||
from image_prediction.flask import make_prediction_server
|
||||
from image_prediction.predictor import Predictor
|
||||
from image_prediction.response import build_response
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
logger.setLevel(CONFIG.service.logging_level)
|
||||
|
||||
|
||||
def main():
|
||||
|
||||
predictor = Predictor()
|
||||
logging.info("Predictor ready.")
|
||||
|
||||
app = Flask(__name__)
|
||||
|
||||
@app.route("/ready", methods=["GET"])
|
||||
def ready():
|
||||
resp = jsonify("OK")
|
||||
resp.status_code = 200
|
||||
return resp
|
||||
|
||||
@app.route("/health", methods=["GET"])
|
||||
def healthy():
|
||||
resp = jsonify("OK")
|
||||
resp.status_code = 200
|
||||
return resp
|
||||
|
||||
@app.route("/", methods=["POST"])
|
||||
def predict():
|
||||
pdf = request.data
|
||||
|
||||
logging.debug("Running predictor on document...")
|
||||
with tempfile.NamedTemporaryFile() as tmp_file:
|
||||
tmp_file.write(pdf)
|
||||
image_metadata_pairs = extract_image_metadata_pairs(tmp_file.name)
|
||||
try:
|
||||
predictions, metadata = classify_images(predictor, image_metadata_pairs)
|
||||
except Exception as err:
|
||||
logging.warning("Analysis failed.")
|
||||
logging.exception(err)
|
||||
response = jsonify("Analysis failed.")
|
||||
response.status_code = 500
|
||||
return response
|
||||
logging.debug(f"Found images in document.")
|
||||
|
||||
response = jsonify(build_response(list(predictions), list(metadata)))
|
||||
|
||||
logging.info("Analysis completed.")
|
||||
def predict(pdf):
|
||||
predictions, metadata = predictor.predict_pdf(pdf)
|
||||
response = build_response(predictions, metadata)
|
||||
return response
|
||||
|
||||
run_prediction_server(app, mode=CONFIG.webserver.mode)
|
||||
predictor = Predictor()
|
||||
logger.info("Predictor ready.")
|
||||
|
||||
prediction_server = make_prediction_server(predict)
|
||||
|
||||
run_prediction_server(prediction_server, mode=CONFIG.webserver.mode)
|
||||
|
||||
|
||||
def run_prediction_server(app, mode="development"):
|
||||
@ -68,5 +40,7 @@ if __name__ == "__main__":
|
||||
logging.getLogger("urllib3").setLevel(logging.ERROR)
|
||||
logging.getLogger("werkzeug").setLevel(logging.ERROR)
|
||||
logging.getLogger("waitress").setLevel(logging.ERROR)
|
||||
logging.getLogger("PIL").setLevel(logging.ERROR)
|
||||
logging.getLogger("h5py").setLevel(logging.ERROR)
|
||||
|
||||
main()
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user