From 3dfe7b861816ef9019103e16a23efd97a08fb617 Mon Sep 17 00:00:00 2001 From: Julius Unverfehrt Date: Thu, 22 Sep 2022 13:53:32 +0200 Subject: [PATCH] RED-4206 wrap queue callback in process to manage memory allocation with the operating system and force deallocation after processing. --- image_prediction/flask.py | 26 +--------------------- image_prediction/utils/process_wrapping.py | 25 +++++++++++++++++++++ src/serve.py | 7 ++++++ test/unit_tests/process_wrapping_test.py | 24 ++++++++++++++++++++ 4 files changed, 57 insertions(+), 25 deletions(-) create mode 100644 image_prediction/utils/process_wrapping.py create mode 100644 test/unit_tests/process_wrapping_test.py diff --git a/image_prediction/flask.py b/image_prediction/flask.py index 4297a6f..9fe2cd2 100644 --- a/image_prediction/flask.py +++ b/image_prediction/flask.py @@ -1,38 +1,14 @@ -import multiprocessing -import traceback from typing import Callable from flask import Flask, request, jsonify from prometheus_client import generate_latest, CollectorRegistry, Summary from image_prediction.utils import get_logger +from image_prediction.utils.process_wrapping import wrap_in_process logger = get_logger() -def run_in_process(func): - p = multiprocessing.Process(target=func) - p.start() - p.join() - - -def wrap_in_process(func_to_wrap): - def build_function_and_run_in_process(*args, **kwargs): - def func(): - try: - result = func_to_wrap(*args, **kwargs) - return_dict["result"] = result - except: - logger.error(traceback.format_exc()) - - manager = multiprocessing.Manager() - return_dict = manager.dict() - run_in_process(func) - return return_dict.get("result", None) - - return build_function_and_run_in_process - - def make_prediction_server(predict_fn: Callable): app = Flask(__name__) registry = CollectorRegistry(auto_describe=True) diff --git a/image_prediction/utils/process_wrapping.py b/image_prediction/utils/process_wrapping.py new file mode 100644 index 0000000..516e40b --- /dev/null +++ b/image_prediction/utils/process_wrapping.py @@ -0,0 +1,25 @@ +import logging +import multiprocessing + + +logger = logging.getLogger("main") + + +def wrap_in_process(fn): + manager = multiprocessing.Manager() + return_queue = manager.list() + + def process_fn(*args, **kwargs): + return_queue.append(fn(*args, **kwargs)) + + def wrapped_fn(*args, **kwargs): + logger.debug("Starting new subprocess") + process = multiprocessing.Process(target=process_fn, args=args, kwargs=kwargs) + process.start() + process.join() + try: + return return_queue.pop(0) + except IndexError: + logger.warning("No results returned by subprocess.") + + return wrapped_fn diff --git a/src/serve.py b/src/serve.py index 65260da..ece6a0b 100644 --- a/src/serve.py +++ b/src/serve.py @@ -7,6 +7,7 @@ from image_prediction.config import Config from image_prediction.locations import CONFIG_FILE from image_prediction.pipeline import load_pipeline from image_prediction.utils.banner import load_banner +from image_prediction.utils.process_wrapping import wrap_in_process from pyinfra import config from pyinfra.queue.queue_manager import QueueManager from pyinfra.storage.storage import get_storage @@ -19,6 +20,12 @@ logger = logging.getLogger("main") logger.setLevel(PYINFRA_CONFIG.logging_level_root) +# A component of the callback (probably tensorflow) does not release allocated memory (see RED-4206). +# See: https://stackoverflow.com/questions/39758094/clearing-tensorflow-gpu-memory-after-model-execution +# Workaround: Manage Memory with the operating system, by wrapping the callback in a sub-process. +# FIXME: Find more fine-grained solution or if the problem occurs persistently for python services, +# FIXME: move the process wrapper to a general module (see RED-4929). +@wrap_in_process def process_request(request_message): dossier_id = request_message["dossierId"] file_id = request_message["fileId"] diff --git a/test/unit_tests/process_wrapping_test.py b/test/unit_tests/process_wrapping_test.py new file mode 100644 index 0000000..d5be35e --- /dev/null +++ b/test/unit_tests/process_wrapping_test.py @@ -0,0 +1,24 @@ +import pytest + +from image_prediction.utils.process_wrapping import wrap_in_process + + +@pytest.fixture +def process_fn_mock(): + def _process(a: str, b: float, c: dict, d: list): + return a * 2, b + 3, c, d[0] + + return _process + + +@pytest.fixture +def parameter(): + return {"a": "A", "b": 0.42, "c": {"x": 1, "y": 2}, "d": [1, 2, 3]} + + +def test_process_wrapper_with_args(process_fn_mock, parameter): + assert process_fn_mock(*parameter.values()) == wrap_in_process(process_fn_mock)(*parameter.values()) + + +def test_process_wrapper_with_kwargs(process_fn_mock, parameter): + assert process_fn_mock(**parameter) == wrap_in_process(process_fn_mock)(**parameter)