import multiprocessing
|
|
|
|
import numpy as np
|
|
from tensorflow import keras
|
|
from tensorflow.keras import layers
|
|
|
|
|
|
def process(predict_fn_wrapper):
    """Run *predict_fn_wrapper* in a child process and return its result.

    The wrapper runs in a separate ``multiprocessing.Process`` and is expected
    to store its output under the ``"result"`` key of the shared dict it is
    handed. Returns that value, or ``None`` if the child never set it (e.g.
    the child crashed before writing).
    """
    # We observed memory doesn't get properly deallocated unless we do this:
    # running the work in a throwaway child process lets the OS reclaim
    # everything when the child exits.
    manager = multiprocessing.Manager()
    return_dict = manager.dict()
    p = multiprocessing.Process(
        target=predict_fn_wrapper,
        args=(return_dict,),
    )
    p.start()
    p.join()
    # Snapshot the manager proxy into a plain dict, then use .get(): same
    # None-on-missing behavior as the old try/except KeyError: pass, but
    # with no silently swallowed exception block.
    return dict(return_dict).get("result")
|
|
|
|
|
|
def make_model():
    """Build and compile a small MNIST classifier: 784 -> 64 (relu) -> 10 logits."""
    inputs = keras.Input(shape=(784,))
    dense = layers.Dense(64, activation="relu")
    x = dense(inputs)
    outputs = layers.Dense(10)(x)
    # Bug fix: ``keras.ServiceEstimator`` does not exist in the Keras API
    # (the original line raised AttributeError). The functional-API container
    # for an inputs/outputs graph is ``keras.Model``.
    model = keras.Model(inputs=inputs, outputs=outputs, name="mnist_model")
    model.compile(
        # from_logits=True because the final Dense layer applies no softmax.
        loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
        optimizer=keras.optimizers.RMSprop(),
        metrics=["accuracy"],
    )
    return model
|
|
|
|
|
|
def make_predict_fn():
    """Return a closure that runs a single prediction on a fresh model.

    Keras bug: constructing the model in outer scope doesn't work, so the
    model is built here and captured by the returned closure.
    """
    model = make_model()

    def predict(*_args):
        # Feed a random dummy sample; the inputs are irrelevant here.
        sample = np.random.random(size=(1, 784))
        return model.predict(sample)

    return predict
|
|
|
|
|
|
def make_predict_fn_wrapper(predict_fn):
    """Adapt *predict_fn* to the child-process calling convention.

    The returned callable invokes ``predict_fn`` with no arguments and
    stores the outcome under the ``"result"`` key of the shared dict it
    receives.
    """

    def predict_fn_wrapper(return_dict):
        outcome = predict_fn()
        return_dict["result"] = outcome

    return predict_fn_wrapper
|
|
|
|
|
|
if __name__ == "__main__":
    # Build the predict closure, adapt it to the child-process protocol,
    # and print the result of the isolated prediction run.
    fn = make_predict_fn()
    wrapper = make_predict_fn_wrapper(fn)
    print(process(wrapper))
|