From 150165367339a4863b2a48036c649fbd2203dd3e Mon Sep 17 00:00:00 2001 From: Matthias Bisping Date: Sat, 2 Apr 2022 00:16:01 +0200 Subject: [PATCH] coverage increased for flask tests --- image_prediction/flask.py | 21 ++++++------------ image_prediction/utils/logger.py | 23 ++++++++++---------- test/unit_tests/mocked_server_test.py | 31 ++++++++++++++++++++++----- 3 files changed, 44 insertions(+), 31 deletions(-) diff --git a/image_prediction/flask.py b/image_prediction/flask.py index 97cd999..7bc6eb1 100644 --- a/image_prediction/flask.py +++ b/image_prediction/flask.py @@ -1,4 +1,5 @@ import multiprocessing +import traceback from typing import Callable from flask import Flask, request, jsonify @@ -39,19 +40,10 @@ def make_prediction_server(predict_fn: Callable): pdf = request.data manager = multiprocessing.Manager() return_dict = manager.dict() - p = multiprocessing.Process( - target=predict_fn_wrapper, - args=( - pdf, - return_dict, - ), - ) + p = multiprocessing.Process(target=predict_fn_wrapper, args=(pdf, return_dict)) p.start() p.join() - try: - return dict(return_dict)["result"] - except KeyError: - raise + return return_dict["result"] logger.info("Analysing document...") try: @@ -59,10 +51,9 @@ def make_prediction_server(predict_fn: Callable): response = jsonify(predictions) logger.debug("Analysis completed.") return response - except Exception as err: - logger.error("Analysis failed.") - logger.exception(err) - response = jsonify("Analysis failed.") + except Exception: + logger.exception(f"Analysis failed\n{traceback.format_exc()}") + response = jsonify("Analysis failed") response.status_code = 500 return response diff --git a/image_prediction/utils/logger.py b/image_prediction/utils/logger.py index c712da4..b2a7767 100644 --- a/image_prediction/utils/logger.py +++ b/image_prediction/utils/logger.py @@ -4,19 +4,20 @@ from image_prediction.config import CONFIG def make_logger_getter(): - logger = logging.getLogger("imclf") - logger.propagate = False - - handler = logging.StreamHandler() - handler.setLevel(CONFIG.service.logging_level) - - log_format = "[%(levelname)s]: %(message)s" - formatter = logging.Formatter(log_format) - - handler.setFormatter(formatter) - logger.addHandler(handler) def get_logger(): + logger = logging.getLogger("imclf") + logger.propagate = False + + handler = logging.StreamHandler() + handler.setLevel(CONFIG.service.logging_level) + + log_format = "[%(levelname)s]: %(message)s" + formatter = logging.Formatter(log_format) + + handler.setFormatter(formatter) + logger.addHandler(handler) + return logger return get_logger diff --git a/test/unit_tests/mocked_server_test.py b/test/unit_tests/mocked_server_test.py index b6bb083..ab82a33 100644 --- a/test/unit_tests/mocked_server_test.py +++ b/test/unit_tests/mocked_server_test.py @@ -1,13 +1,27 @@ import json +import logging import pytest from image_prediction.flask import make_prediction_server +from image_prediction.utils import get_logger + +logger = get_logger() +logger.setLevel(logging.CRITICAL + 1) + + +def predict_fn(x: bytes): + x = int(x.decode()) + match x: + case 42: + return True + case _: + raise Exception @pytest.fixture def server(): - server = make_prediction_server(lambda _: 42) + server = make_prediction_server(predict_fn) server.config.update({"TESTING": True}) return server @@ -17,16 +31,23 @@ def client(server): return server.test_client() -def test_server_predict(client): - response = client.post("/predict") - assert json.loads(response.data) == 42 +def test_server_predict_success(client): + response = client.post("/predict", data="42") + assert json.loads(response.data) + + +def test_server_predict_failure(client): + response = client.post("/predict", data="13") + assert response.status_code == 500 def test_server_health_check(client): - response = client.get("/health") + response = client.get("/ready") assert response.status_code == 200 + assert response.json == "OK" def test_server_ready_check(client): response = client.get("/health") assert response.status_code == 200 + assert response.json == "OK"