From 130d0e8b23e0375a6fd240ac8aa00492c341a716 Mon Sep 17 00:00:00 2001 From: Julius Unverfehrt Date: Mon, 21 Mar 2022 13:34:54 +0100 Subject: [PATCH] add minimal not working example for keras bug in multiprocess process --- scripts/keras_MnWE.py | 60 +++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 60 insertions(+) create mode 100644 scripts/keras_MnWE.py diff --git a/scripts/keras_MnWE.py b/scripts/keras_MnWE.py new file mode 100644 index 0000000..4936be8 --- /dev/null +++ b/scripts/keras_MnWE.py @@ -0,0 +1,60 @@ +import multiprocessing + +import numpy as np +from tensorflow import keras +from tensorflow.keras import layers + + +def process(predict_fn_wrapper): + # We observed memory doesn't get properly deallocated unless we do this: + manager = multiprocessing.Manager() + return_dict = manager.dict() + p = multiprocessing.Process( + target=predict_fn_wrapper, + args=( + return_dict, + ), + ) + p.start() + p.join() + try: + return dict(return_dict)["result"] + except KeyError: + pass + + +def make_model(): + inputs = keras.Input(shape=(784,)) + dense = layers.Dense(64, activation="relu") + x = dense(inputs) + outputs = layers.Dense(10)(x) + model = keras.Model(inputs=inputs, outputs=outputs, name="mnist_model") + model.compile( + loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True), + optimizer=keras.optimizers.RMSprop(), + metrics=["accuracy"], + ) + return model + + +def make_predict_fn(): + # Keras bug: doesn't work in outer scope + model = make_model() + + def predict(*args): + # model = make_model() + return model.predict(np.random.random(size=(1, 784))) + + return predict + + +def make_predict_fn_wrapper(predict_fn): + def predict_fn_wrapper(return_dict): + return_dict["result"] = predict_fn() + + return predict_fn_wrapper + + +if __name__ == "__main__": + predict_fn = make_predict_fn() + print(process(make_predict_fn_wrapper(predict_fn)))