开发者学堂课程【深度学习框架 TensorFlow 入门:模型保存与加载】学习笔记,与课程紧密联系,让用户快速学习知识。
课程地址:https://developer.aliyun.com/learning/course/773/detail/13550
模型保存与加载
内容介绍:
一、模型的保存与加载
二、实现代码
一、模型的保存与加载
tf.train.Saver(var_list= None,max.to_keep=5)
。保存和加载模型(保存文件格式: checkpoint 文件)
。var_list:指定将要保存和还原的变量,它可以作为-个dict或一个列表传递.
。max_to_keep:指示要保留的最近检查点文件的最大数量。创建新文件时,会删除较旧的文件。如果无或0,则保留所有检查点文件。默认为5 (即保留最新的5个检查点文件。)
使用
例如:
指定目录+模型名字
saver.save(sess,'/tmp/ckpt/test/myregression.ckpt')
saver.restore(sess,’/tmp/ckpt/test/myregression.ckpt')
如要判断模型是否存在,直接指定目录
checkpoint = tf.train.latest_checkpoint("./tmp/modeL/")
saver.restore(sess,checkpoint )
二、实现代码
#创建 Saver 对象,
Saver=tf.train.Saver()
#保存模型
if i % 10==0:
saver.save(sess,"./tmp/modeL/my_linear.ckpt")
#加载模型
if os.path.exists("./tmp/modeL/checkpoint"):
saver.restore(sess,"./tmp/modeL/my_Linear.ckpt")
print(“训练后模型参数为:权重%f,偏置%f,损失为%f" % (weights. eval(), bias.eval(), error.eval()))