add minimal not working example for keras bug in multiprocess process
This commit is contained in:
parent
2589598b05
commit
130d0e8b23
60
scripts/keras_MnWE.py
Normal file
60
scripts/keras_MnWE.py
Normal file
@ -0,0 +1,60 @@
|
||||
import multiprocessing
|
||||
|
||||
import numpy as np
|
||||
from tensorflow import keras
|
||||
from tensorflow.keras import layers
|
||||
|
||||
|
||||
def process(predict_fn_wrapper):
|
||||
# We observed memory doesn't get properly deallocated unless we do this:
|
||||
manager = multiprocessing.Manager()
|
||||
return_dict = manager.dict()
|
||||
p = multiprocessing.Process(
|
||||
target=predict_fn_wrapper,
|
||||
args=(
|
||||
return_dict,
|
||||
),
|
||||
)
|
||||
p.start()
|
||||
p.join()
|
||||
try:
|
||||
return dict(return_dict)["result"]
|
||||
except KeyError:
|
||||
pass
|
||||
|
||||
|
||||
def make_model():
|
||||
inputs = keras.Input(shape=(784,))
|
||||
dense = layers.Dense(64, activation="relu")
|
||||
x = dense(inputs)
|
||||
outputs = layers.Dense(10)(x)
|
||||
model = keras.Model(inputs=inputs, outputs=outputs, name="mnist_model")
|
||||
model.compile(
|
||||
loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
|
||||
optimizer=keras.optimizers.RMSprop(),
|
||||
metrics=["accuracy"],
|
||||
)
|
||||
return model
|
||||
|
||||
|
||||
def make_predict_fn():
|
||||
# Keras bug: doesn't work in outer scope
|
||||
model = make_model()
|
||||
|
||||
def predict(*args):
|
||||
# model = make_model()
|
||||
return model.predict(np.random.random(size=(1, 784)))
|
||||
|
||||
return predict
|
||||
|
||||
|
||||
def make_predict_fn_wrapper(predict_fn):
|
||||
def predict_fn_wrapper(return_dict):
|
||||
return_dict["result"] = predict_fn()
|
||||
|
||||
return predict_fn_wrapper
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
predict_fn = make_predict_fn()
|
||||
print(process(make_predict_fn_wrapper(predict_fn)))
|
||||
Loading…
x
Reference in New Issue
Block a user