2022-03-25 11:42:31 +01:00

59 lines
1.4 KiB
Python

import multiprocessing
import numpy as np
from tensorflow import keras
from tensorflow.keras import layers
def process(predict_fn_wrapper):
    """Run ``predict_fn_wrapper`` in a child process and return its result.

    The callable is executed in a fresh ``multiprocessing.Process`` and is
    given a ``Manager``-backed dict as its only argument; it is expected to
    store its output under the ``"result"`` key.

    Args:
        predict_fn_wrapper: Callable taking one mutable-mapping argument.
            Under the "spawn"/"forkserver" start methods it must also be
            picklable (e.g. defined at module level, not a closure).

    Returns:
        The value stored under ``"result"``, or ``None`` if the child did not
        store one (for example because it raised or was killed).
    """
    # We observed memory doesn't get properly deallocated unless the work is
    # run in a separate process like this.
    # The Manager is used as a context manager so its server process is shut
    # down when we are done (it previously leaked on every call).
    with multiprocessing.Manager() as manager:
        return_dict = manager.dict()
        p = multiprocessing.Process(
            target=predict_fn_wrapper,
            args=(return_dict,),
        )
        p.start()
        p.join()
        # Read the value while the manager is still alive: the proxy stops
        # working once the ``with`` block exits.
        result = return_dict.get("result")
    return result
def make_model():
    """Build and compile a small dense classifier.

    Architecture: 784-feature input -> Dense(64, relu) -> Dense(10) logits.
    The 784/10 sizes and the "mnist_model" name suggest flattened 28x28
    MNIST digits; confirm against the data actually fed in.

    Returns:
        A compiled ``keras.Model`` (sparse categorical cross-entropy on
        logits, RMSprop optimizer, accuracy metric).
    """
    inputs = keras.Input(shape=(784,))
    x = layers.Dense(64, activation="relu")(inputs)
    # No activation on the output layer: it emits raw logits, which is why
    # the loss below is built with ``from_logits=True``.
    outputs = layers.Dense(10)(x)
    # Functional API: wire the input/output tensors into a Model.
    model = keras.Model(inputs=inputs, outputs=outputs, name="mnist_model")
    model.compile(
        loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
        optimizer=keras.optimizers.RMSprop(),
        metrics=["accuracy"],
    )
    return model
def make_predict_fn():
    """Build a model once and return a closure that predicts on random data.

    Returns:
        A function accepting any positional arguments (all ignored) that
        returns the model's ``predict`` output for a freshly drawn
        ``(1, 784)`` uniform-random batch.
    """
    # NOTE(review): the original author notes "Keras bug: doesn't work in
    # outer scope", i.e. the model must be created inside this factory rather
    # than at module level. The exact cause is not visible here; confirm.
    net = make_model()

    def predict(*args):
        # Arguments are accepted for call-compatibility but not used.
        batch = np.random.random(size=(1, 784))
        return net.predict(batch)

    return predict
def make_predict_fn_wrapper(predict_fn):
    """Adapt a zero-argument predictor into a target that reports via a dict.

    Args:
        predict_fn: Callable invoked with no arguments.

    Returns:
        A function taking one mutable mapping. It calls ``predict_fn()`` and
        stores the outcome under that mapping's ``"result"`` key, returning
        ``None`` itself.
    """

    def predict_fn_wrapper(return_dict):
        # Report through the mapping rather than a return value, because
        # a child process (as used by ``process``) cannot return one to its parent.
        outcome = predict_fn()
        return_dict["result"] = outcome

    return predict_fn_wrapper
if __name__ == "__main__":
    # Build the model in this (parent) process, then run a single prediction
    # in a child process via ``process`` and print what it returned (None if
    # the child stored no result).
    predict_fn = make_predict_fn()
    target = make_predict_fn_wrapper(predict_fn)
    result = process(target)
    print(result)