coverage increased for flask tests

This commit is contained in:
Matthias Bisping 2022-04-02 00:16:01 +02:00
parent b4b929b65f
commit 1501653673
3 changed files with 44 additions and 31 deletions

View File

@ -1,4 +1,5 @@
import multiprocessing
import traceback
from typing import Callable
from flask import Flask, request, jsonify
@ -39,19 +40,10 @@ def make_prediction_server(predict_fn: Callable):
pdf = request.data
manager = multiprocessing.Manager()
return_dict = manager.dict()
p = multiprocessing.Process(
target=predict_fn_wrapper,
args=(
pdf,
return_dict,
),
)
p = multiprocessing.Process(target=predict_fn_wrapper, args=(pdf, return_dict))
p.start()
p.join()
try:
return dict(return_dict)["result"]
except KeyError:
raise
return return_dict["result"]
logger.info("Analysing document...")
try:
@ -59,10 +51,9 @@ def make_prediction_server(predict_fn: Callable):
response = jsonify(predictions)
logger.debug("Analysis completed.")
return response
except Exception as err:
logger.error("Analysis failed.")
logger.exception(err)
response = jsonify("Analysis failed.")
except Exception:
logger.exception(f"Analysis failed\n{traceback.format_exc()}")
response = jsonify("Analysis failed")
response.status_code = 500
return response

View File

@ -4,19 +4,20 @@ from image_prediction.config import CONFIG
def make_logger_getter():
logger = logging.getLogger("imclf")
logger.propagate = False
handler = logging.StreamHandler()
handler.setLevel(CONFIG.service.logging_level)
log_format = "[%(levelname)s]: %(message)s"
formatter = logging.Formatter(log_format)
handler.setFormatter(formatter)
logger.addHandler(handler)
def get_logger():
logger = logging.getLogger("imclf")
logger.propagate = False
handler = logging.StreamHandler()
handler.setLevel(CONFIG.service.logging_level)
log_format = "[%(levelname)s]: %(message)s"
formatter = logging.Formatter(log_format)
handler.setFormatter(formatter)
logger.addHandler(handler)
return logger
return get_logger

View File

@ -1,13 +1,27 @@
import json
import logging
import pytest
from image_prediction.flask import make_prediction_server
from image_prediction.utils import get_logger
logger = get_logger()
logger.setLevel(logging.CRITICAL + 1)
def predict_fn(x: bytes):
x = int(x.decode())
match x:
case 42:
return True
case _:
raise Exception
@pytest.fixture
def server():
server = make_prediction_server(lambda _: 42)
server = make_prediction_server(predict_fn)
server.config.update({"TESTING": True})
return server
@ -17,16 +31,23 @@ def client(server):
return server.test_client()
def test_server_predict(client):
response = client.post("/predict")
assert json.loads(response.data) == 42
def test_server_predict_success(client):
response = client.post("/predict", data="42")
assert json.loads(response.data)
def test_server_predict_failure(client):
response = client.post("/predict", data="13")
assert response.status_code == 500
def test_server_health_check(client):
response = client.get("/health")
response = client.get("/ready")
assert response.status_code == 200
assert response.json == "OK"
def test_server_ready_check(client):
response = client.get("/health")
assert response.status_code == 200
assert response.json == "OK"