From 1776e3083c97025e699d579f936dd0cc6e1fe152 Mon Sep 17 00:00:00 2001 From: Julius Unverfehrt Date: Mon, 21 Mar 2022 13:54:27 +0100 Subject: [PATCH] blacckkkyykykykyk --- image_prediction/flask.py | 1 - scripts/keras_MnWE.py | 4 +--- src/serve.py | 1 - test/conftest.py | 2 +- test/unit_tests/test_predictor.py | 19 +++++++++++++++---- 5 files changed, 17 insertions(+), 10 deletions(-) diff --git a/image_prediction/flask.py b/image_prediction/flask.py index 11aa356..5cf40c2 100644 --- a/image_prediction/flask.py +++ b/image_prediction/flask.py @@ -26,7 +26,6 @@ def make_prediction_server(predict_fn: Callable): @app.route("/", methods=["POST"]) def predict(): - def predict_fn_wrapper(pdf, return_dict): return_dict["result"] = predict_fn(pdf) diff --git a/scripts/keras_MnWE.py b/scripts/keras_MnWE.py index 4936be8..05a45dd 100644 --- a/scripts/keras_MnWE.py +++ b/scripts/keras_MnWE.py @@ -11,9 +11,7 @@ def process(predict_fn_wrapper): return_dict = manager.dict() p = multiprocessing.Process( target=predict_fn_wrapper, - args=( - return_dict, - ), + args=(return_dict,), ) p.start() p.join() diff --git a/src/serve.py b/src/serve.py index af7a133..666ca80 100644 --- a/src/serve.py +++ b/src/serve.py @@ -12,7 +12,6 @@ logger = get_logger() def main(): - def predict(pdf): # Keras model.predict stalls when model was loaded in different process # https://stackoverflow.com/questions/42504669/keras-tensorflow-and-multiprocessing-in-python diff --git a/test/conftest.py b/test/conftest.py index 73f2e37..71b37d1 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -67,4 +67,4 @@ def predictor(): @pytest.fixture def test_pdf(): with open("./test/test_data/f2dc689ca794fccb8cd38b95f2bf6ba9.pdf", "rb") as f: - return f.read() \ No newline at end of file + return f.read() diff --git a/test/unit_tests/test_predictor.py b/test/unit_tests/test_predictor.py index 522ec4b..0da6f91 100644 --- a/test/unit_tests/test_predictor.py +++ b/test/unit_tests/test_predictor.py @@ -9,7 +9,18 @@ def test_predict_pdf_works(predictor, test_pdf): metadata = list(metadata) metadata = dict(**metadata[0]) metadata.pop("document_filename") # temp filename cannot be tested - assert metadata == {'px_width': 389.0, 'px_height': 389.0, 'width': 194.49999000000003, - 'height': 194.49998999999997, 'x1': 320.861, 'x2': 515.36099, 'y1': 347.699, 'y2': 542.19899, - 'page_width': 595.2800000000001, 'page_height': 841.89, 'page_rotation': 0, 'page_idx': 1, - 'n_pages': 3} + assert metadata == { + "px_width": 389.0, + "px_height": 389.0, + "width": 194.49999000000003, + "height": 194.49998999999997, + "x1": 320.861, + "x2": 515.36099, + "y1": 347.699, + "y2": 542.19899, + "page_width": 595.2800000000001, + "page_height": 841.89, + "page_rotation": 0, + "page_idx": 1, + "n_pages": 3, + }