1 原因
提高GPU利用率
2 Example
参考官网的介绍通过 Keras 模型创建 Estimator
# 通过keras API 构建模型
model = build_model()
# 产生训练集sample 和label
x,y = generator_data(data_size,SNRdb)
# 用Dataset封装,加快训练
dataset_xy=tf.data.Dataset.from_tensor_slices((x,y)).shuffle(5000).batch(batchs).prefetch(tf.data.experimental.AUTOTUNE).repeat()
# 临时文件
import tempfile
model_dir = tempfile.mkdtemp()
# 用Estimator进行训练
keras_estimator = tf.keras.estimator.model_to_estimator(keras_model=model,model_dir=model_dir)
# 预测
valid_data =...#Dataset格式的验证集
eval_result = keras_estimator.evaluate(input_fn=valid_data, steps=10)
print('Eval result: {}'.format(eval_result))