coverage increased for flask tests
This commit is contained in:
parent
b4b929b65f
commit
1501653673
@ -1,4 +1,5 @@
|
|||||||
import multiprocessing
|
import multiprocessing
|
||||||
|
import traceback
|
||||||
from typing import Callable
|
from typing import Callable
|
||||||
|
|
||||||
from flask import Flask, request, jsonify
|
from flask import Flask, request, jsonify
|
||||||
@ -39,19 +40,10 @@ def make_prediction_server(predict_fn: Callable):
|
|||||||
pdf = request.data
|
pdf = request.data
|
||||||
manager = multiprocessing.Manager()
|
manager = multiprocessing.Manager()
|
||||||
return_dict = manager.dict()
|
return_dict = manager.dict()
|
||||||
p = multiprocessing.Process(
|
p = multiprocessing.Process(target=predict_fn_wrapper, args=(pdf, return_dict))
|
||||||
target=predict_fn_wrapper,
|
|
||||||
args=(
|
|
||||||
pdf,
|
|
||||||
return_dict,
|
|
||||||
),
|
|
||||||
)
|
|
||||||
p.start()
|
p.start()
|
||||||
p.join()
|
p.join()
|
||||||
try:
|
return return_dict["result"]
|
||||||
return dict(return_dict)["result"]
|
|
||||||
except KeyError:
|
|
||||||
raise
|
|
||||||
|
|
||||||
logger.info("Analysing document...")
|
logger.info("Analysing document...")
|
||||||
try:
|
try:
|
||||||
@ -59,10 +51,9 @@ def make_prediction_server(predict_fn: Callable):
|
|||||||
response = jsonify(predictions)
|
response = jsonify(predictions)
|
||||||
logger.debug("Analysis completed.")
|
logger.debug("Analysis completed.")
|
||||||
return response
|
return response
|
||||||
except Exception as err:
|
except Exception:
|
||||||
logger.error("Analysis failed.")
|
logger.exception(f"Analysis failed\n{traceback.format_exc()}")
|
||||||
logger.exception(err)
|
response = jsonify("Analysis failed")
|
||||||
response = jsonify("Analysis failed.")
|
|
||||||
response.status_code = 500
|
response.status_code = 500
|
||||||
return response
|
return response
|
||||||
|
|
||||||
|
|||||||
@ -4,19 +4,20 @@ from image_prediction.config import CONFIG
|
|||||||
|
|
||||||
|
|
||||||
def make_logger_getter():
|
def make_logger_getter():
|
||||||
logger = logging.getLogger("imclf")
|
|
||||||
logger.propagate = False
|
|
||||||
|
|
||||||
handler = logging.StreamHandler()
|
|
||||||
handler.setLevel(CONFIG.service.logging_level)
|
|
||||||
|
|
||||||
log_format = "[%(levelname)s]: %(message)s"
|
|
||||||
formatter = logging.Formatter(log_format)
|
|
||||||
|
|
||||||
handler.setFormatter(formatter)
|
|
||||||
logger.addHandler(handler)
|
|
||||||
|
|
||||||
def get_logger():
|
def get_logger():
|
||||||
|
logger = logging.getLogger("imclf")
|
||||||
|
logger.propagate = False
|
||||||
|
|
||||||
|
handler = logging.StreamHandler()
|
||||||
|
handler.setLevel(CONFIG.service.logging_level)
|
||||||
|
|
||||||
|
log_format = "[%(levelname)s]: %(message)s"
|
||||||
|
formatter = logging.Formatter(log_format)
|
||||||
|
|
||||||
|
handler.setFormatter(formatter)
|
||||||
|
logger.addHandler(handler)
|
||||||
|
|
||||||
return logger
|
return logger
|
||||||
|
|
||||||
return get_logger
|
return get_logger
|
||||||
|
|||||||
@ -1,13 +1,27 @@
|
|||||||
import json
|
import json
|
||||||
|
import logging
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from image_prediction.flask import make_prediction_server
|
from image_prediction.flask import make_prediction_server
|
||||||
|
from image_prediction.utils import get_logger
|
||||||
|
|
||||||
|
logger = get_logger()
|
||||||
|
logger.setLevel(logging.CRITICAL + 1)
|
||||||
|
|
||||||
|
|
||||||
|
def predict_fn(x: bytes):
|
||||||
|
x = int(x.decode())
|
||||||
|
match x:
|
||||||
|
case 42:
|
||||||
|
return True
|
||||||
|
case _:
|
||||||
|
raise Exception
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def server():
|
def server():
|
||||||
server = make_prediction_server(lambda _: 42)
|
server = make_prediction_server(predict_fn)
|
||||||
server.config.update({"TESTING": True})
|
server.config.update({"TESTING": True})
|
||||||
return server
|
return server
|
||||||
|
|
||||||
@ -17,16 +31,23 @@ def client(server):
|
|||||||
return server.test_client()
|
return server.test_client()
|
||||||
|
|
||||||
|
|
||||||
def test_server_predict(client):
|
def test_server_predict_success(client):
|
||||||
response = client.post("/predict")
|
response = client.post("/predict", data="42")
|
||||||
assert json.loads(response.data) == 42
|
assert json.loads(response.data)
|
||||||
|
|
||||||
|
|
||||||
|
def test_server_predict_failure(client):
|
||||||
|
response = client.post("/predict", data="13")
|
||||||
|
assert response.status_code == 500
|
||||||
|
|
||||||
|
|
||||||
def test_server_health_check(client):
|
def test_server_health_check(client):
|
||||||
response = client.get("/health")
|
response = client.get("/ready")
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
|
assert response.json == "OK"
|
||||||
|
|
||||||
|
|
||||||
def test_server_ready_check(client):
|
def test_server_ready_check(client):
|
||||||
response = client.get("/health")
|
response = client.get("/health")
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
|
assert response.json == "OK"
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user