"""Run a Keras prediction inside a short-lived child process.

The prediction callable is executed in a separate ``multiprocessing.Process``
and its result is handed back to the parent through a manager-backed dict.
"""

import logging
import multiprocessing

import numpy as np
from tensorflow import keras
from tensorflow.keras import layers

logger = logging.getLogger(__name__)


def process(predict_fn_wrapper):
    """Run ``predict_fn_wrapper`` in a child process and return its result.

    Args:
        predict_fn_wrapper: Callable taking a single dict-like argument. It is
            expected to store its output under the ``"result"`` key.
            Note that start methods other than ``fork`` require the callable
            to be picklable, which nested functions/closures are not.

    Returns:
        The value stored under ``"result"``, or ``None`` when the child
        finished without storing one (for example because it crashed). In the
        latter case a warning including the child's exit code is logged.
    """
    # We observed memory doesn't get properly deallocated unless we do this:
    # the work runs in a separate process that exits afterwards.
    # The manager is used as a context manager so its server process is
    # always shut down, even if starting or joining the worker raises.
    with multiprocessing.Manager() as manager:
        return_dict = manager.dict()
        worker = multiprocessing.Process(
            target=predict_fn_wrapper,
            args=(return_dict,),
        )
        worker.start()
        worker.join()
        # Copy into a plain dict while the manager is still alive; the proxy
        # becomes unusable once the manager has shut down.
        results = dict(return_dict)

    if "result" not in results:
        logger.warning(
            "Worker process exited with code %s without storing a result",
            worker.exitcode,
        )
        return None
    return results["result"]


def make_model():
    """Build and compile a small dense classifier.

    The model maps a 784-feature input through a 64-unit ReLU layer to 10
    output logits. It is compiled with sparse categorical cross-entropy
    (``from_logits=True``), the RMSprop optimizer and an accuracy metric.

    Returns:
        The compiled ``keras.Model`` named ``"mnist_model"``.
    """
    inputs = keras.Input(shape=(784,))
    dense = layers.Dense(64, activation="relu")
    x = dense(inputs)
    outputs = layers.Dense(10)(x)
    model = keras.Model(inputs=inputs, outputs=outputs, name="mnist_model")
    model.compile(
        loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
        optimizer=keras.optimizers.RMSprop(),
        metrics=["accuracy"],
    )
    return model


def make_predict_fn():
    """Create a prediction callable bound to a freshly built model.

    The model is built once, when this function is called, and captured by
    the returned closure.

    Returns:
        A callable that accepts (and ignores) any positional arguments and
        returns the model's prediction for one random 784-feature sample.
    """
    # Keras bug: doesn't work in outer scope
    model = make_model()

    def predict(*args):
        # NOTE(review): the original kept a commented-out alternative here
        # that built the model inside ``predict`` instead
        # (``service_estimator = make_model()``) -- presumably the workaround
        # for the Keras issue mentioned above; confirm before relying on it.
        return model.predict(np.random.random(size=(1, 784)))

    return predict


def make_predict_fn_wrapper(predict_fn):
    """Adapt ``predict_fn`` to the dict-based protocol used by ``process``.

    Args:
        predict_fn: Zero-argument callable producing the prediction.

    Returns:
        A callable that takes a dict-like object and stores the output of
        ``predict_fn()`` under its ``"result"`` key.
    """

    def predict_fn_wrapper(return_dict):
        return_dict["result"] = predict_fn()

    return predict_fn_wrapper


if __name__ == "__main__":
    predict_fn = make_predict_fn()
    print(process(make_predict_fn_wrapper(predict_fn)))