Tensorflow |(5)模型保存与恢复、自定义命令行参数

简介: Tensorflow |(5)模型保存与恢复、自定义命令行参数

Tensorflow |(1)初识Tensorflow


Tensorflow |(2)张量的阶和数据类型及张量操作


Tensorflow |(3)变量的的创建、初始化、保存和加载


Tensorflow |(4)名称域、图 和会话


Tensorflow |(5)模型保存与恢复、自定义命令行参数


模型保存与恢复、自定义命令行参数、

在我们训练或者测试过程中,总会遇到需要保存训练完成的模型,然后从中恢复继续我们的测试或者其它使用。模型的保存和恢复也是通过tf.train.Saver类去实现,它主要通过将Saver类添加OPS保存和恢复变量到checkpoint。它还提供了运行这些操作的便利方法。


tf.train.Saver(var_list=None, reshape=False, sharded=False, max_to_keep=5, keep_checkpoint_every_n_hours=10000.0, name=None, restore_sequentially=False, saver_def=None, builder=None, defer_build=False, allow_empty=False, write_version=tf.SaverDef.V2, pad_step_number=False)


var_list:指定将要保存和还原的变量。它可以作为一个dict或一个列表传递.

max_to_keep:指示要保留的最近检查点文件的最大数量。创建新文件时,会删除较旧的文件。如果无或0,则保留所有检查点文件。默认为5(即保留最新的5个检查点文件。)

keep_checkpoint_every_n_hours:多久生成一个新的检查点文件。默认为10,000小时

保存

保存我们的模型需要调用Saver.save()方法。save(sess, save_path, global_step=None),checkpoint是专有格式的二进制文件,将变量名称映射到张量值。

————————————————

版权声明:本文为CSDN博主「DrugAI」的原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接及本声明。

原文链接:https://blog.csdn.net/u012325865/article/details/104346708Tensorflow |(1)初识Tensorflow


Tensorflow |(2)张量的阶和数据类型及张量操作


Tensorflow |(3)变量的的创建、初始化、保存和加载


Tensorflow |(4)名称域、图 和会话


Tensorflow |(5)模型保存与恢复、自定义命令行参数


模型保存与恢复、自定义命令行参数、

在我们训练或者测试过程中,总会遇到需要保存训练完成的模型,然后从中恢复继续我们的测试或者其它使用。模型的保存和恢复也是通过tf.train.Saver类去实现,它主要通过将Saver类添加OPS保存和恢复变量到checkpoint。它还提供了运行这些操作的便利方法。


tf.train.Saver(var_list=None, reshape=False, sharded=False, max_to_keep=5, keep_checkpoint_every_n_hours=10000.0, name=None, restore_sequentially=False, saver_def=None, builder=None, defer_build=False, allow_empty=False, write_version=tf.SaverDef.V2, pad_step_number=False)


var_list:指定将要保存和还原的变量。它可以作为一个dict或一个列表传递.

max_to_keep:指示要保留的最近检查点文件的最大数量。创建新文件时,会删除较旧的文件。如果无或0,则保留所有检查点文件。默认为5(即保留最新的5个检查点文件。)

keep_checkpoint_every_n_hours:多久生成一个新的检查点文件。默认为10,000小时

保存

保存我们的模型需要调用Saver.save()方法。save(sess, save_path, global_step=None),checkpoint是专有格式的二进制文件,将变量名称映射到张量值。

import tensorflow as tf
a = tf.Variable([[1.0,2.0]],name="a")
b = tf.Variable([[3.0],[4.0]],name="b")
c = tf.matmul(a,b)
saver=tf.train.Saver()
with tf.Session() as sess:
    tf.global_variables_initializer().run()
    print(sess.run(c))
    saver.save(sess, '/tmp/ckpt/test/matmul')

我们可以看保存了什么文件

image.png

在多次训练的时候可以指定多少间隔生成检查点文件

saver.save(sess, '/tmp/ckpt/test/matmu', global_step=0) ==> filename: 'matmu-0'
saver.save(sess, '/tmp/ckpt/test/matmu', global_step=1000) ==> filename: 'matmu-1000'


恢复

恢复模型的方法是restore(sess, save_path),save_path是以前保存参数的路径,我们可以使用tf.train.latest_checkpoint来获取最近的检查点文件(也恶意直接写文件目录)

import tensorflow as tf
a = tf.Variable([[1.0,2.0]],name="a")
b = tf.Variable([[3.0],[4.0]],name="b")
c = tf.matmul(a,b)
saver=tf.train.Saver(max_to_keep=1)
with tf.Session() as sess:
    tf.global_variables_initializer().run()
    print(sess.run(c))
    saver.save(sess, '/tmp/ckpt/test/matmul')
    # 恢复模型
    model_file = tf.train.latest_checkpoint('/tmp/ckpt/test/')
    saver.restore(sess, model_file)
    print(sess.run([c], feed_dict={a: [[5.0,6.0]], b: [[7.0],[8.0]]}))

自定义命令行参数

tf.app.run(),默认调用main()函数,运行程序。main(argv)必须传一个参数。


tf.app.flags,它支持应用从命令行接受参数,可以用来指定集群配置等。在tf.app.flags下面有各种定义参数的类型


DEFINE_string(flag_name, default_value, docstring)


DEFINE_integer(flag_name, default_value, docstring)


DEFINE_boolean(flag_name, default_value, docstring)


DEFINE_float(flag_name, default_value, docstring)


第一个也就是参数的名字,路径、大小等等。第二个参数提供具体的值。第三个参数是说明文档


tf.app.flags.FLAGS,在flags有一个FLAGS标志,它在程序中可以调用到我们前面具体定义的flag_name.

import tensorflow as tf
FLAGS = tf.app.flags.FLAGS
tf.app.flags.DEFINE_string('data_dir', '/tmp/tensorflow/mnist/input_data',
                           """数据集目录""")
tf.app.flags.DEFINE_integer('max_steps', 2000,
                            """训练次数""")
tf.app.flags.DEFINE_string('summary_dir', '/tmp/summary/mnist/convtrain',
                           """事件文件目录""")
def main(argv):
    print(FLAGS.data_dir)
    print(FLAGS.max_steps)
    print(FLAGS.summary_dir)
    print(argv)
if __name__=="__main__":
    tf.app.run()


目录
相关文章
|
3月前
|
机器学习/深度学习 TensorFlow 算法框架/工具
深度学习之格式转换笔记(三):keras(.hdf5)模型转TensorFlow(.pb) 转TensorRT(.uff)格式
将Keras训练好的.hdf5模型转换为TensorFlow的.pb模型,然后再转换为TensorRT支持的.uff格式,并提供了转换代码和测试步骤。
118 3
深度学习之格式转换笔记(三):keras(.hdf5)模型转TensorFlow(.pb) 转TensorRT(.uff)格式
|
2月前
|
机器学习/深度学习 数据采集 数据可视化
TensorFlow,一款由谷歌开发的开源深度学习框架,详细讲解了使用 TensorFlow 构建深度学习模型的步骤
本文介绍了 TensorFlow,一款由谷歌开发的开源深度学习框架,详细讲解了使用 TensorFlow 构建深度学习模型的步骤,包括数据准备、模型定义、损失函数与优化器选择、模型训练与评估、模型保存与部署,并展示了构建全连接神经网络的具体示例。此外,还探讨了 TensorFlow 的高级特性,如自动微分、模型可视化和分布式训练,以及其在未来的发展前景。
195 5
|
2月前
|
机器学习/深度学习 人工智能 算法
【手写数字识别】Python+深度学习+机器学习+人工智能+TensorFlow+算法模型
手写数字识别系统,使用Python作为主要开发语言,基于深度学习TensorFlow框架,搭建卷积神经网络算法。并通过对数据集进行训练,最后得到一个识别精度较高的模型。并基于Flask框架,开发网页端操作平台,实现用户上传一张图片识别其名称。
134 0
【手写数字识别】Python+深度学习+机器学习+人工智能+TensorFlow+算法模型
|
2月前
|
机器学习/深度学习 人工智能 算法
基于深度学习的【蔬菜识别】系统实现~Python+人工智能+TensorFlow+算法模型
蔬菜识别系统,本系统使用Python作为主要编程语言,通过收集了8种常见的蔬菜图像数据集('土豆', '大白菜', '大葱', '莲藕', '菠菜', '西红柿', '韭菜', '黄瓜'),然后基于TensorFlow搭建卷积神经网络算法模型,通过多轮迭代训练最后得到一个识别精度较高的模型文件。在使用Django开发web网页端操作界面,实现用户上传一张蔬菜图片识别其名称。
136 0
基于深度学习的【蔬菜识别】系统实现~Python+人工智能+TensorFlow+算法模型
|
2月前
|
机器学习/深度学习 人工智能 算法
【车辆车型识别】Python+卷积神经网络算法+深度学习+人工智能+TensorFlow+算法模型
车辆车型识别,使用Python作为主要编程语言,通过收集多种车辆车型图像数据集,然后基于TensorFlow搭建卷积网络算法模型,并对数据集进行训练,最后得到一个识别精度较高的模型文件。再基于Django搭建web网页端操作界面,实现用户上传一张车辆图片识别其类型。
111 0
【车辆车型识别】Python+卷积神经网络算法+深度学习+人工智能+TensorFlow+算法模型
|
4月前
|
机器学习/深度学习 人工智能 算法
鸟类识别系统Python+卷积神经网络算法+深度学习+人工智能+TensorFlow+ResNet50算法模型+图像识别
鸟类识别系统。本系统采用Python作为主要开发语言,通过使用加利福利亚大学开源的200种鸟类图像作为数据集。使用TensorFlow搭建ResNet50卷积神经网络算法模型,然后进行模型的迭代训练,得到一个识别精度较高的模型,然后在保存为本地的H5格式文件。在使用Django开发Web网页端操作界面,实现用户上传一张鸟类图像,识别其名称。
143 12
鸟类识别系统Python+卷积神经网络算法+深度学习+人工智能+TensorFlow+ResNet50算法模型+图像识别
|
3月前
|
机器学习/深度学习 移动开发 TensorFlow
深度学习之格式转换笔记(四):Keras(.h5)模型转化为TensorFlow(.pb)模型
本文介绍了如何使用Python脚本将Keras模型转换为TensorFlow的.pb格式模型,包括加载模型、重命名输出节点和量化等步骤,以便在TensorFlow中进行部署和推理。
166 0
|
5月前
|
机器学习/深度学习 存储 前端开发
实战揭秘:如何借助TensorFlow.js的强大力量,轻松将高效能的机器学习模型无缝集成到Web浏览器中,从而打造智能化的前端应用并优化用户体验
【8月更文挑战第31天】将机器学习模型集成到Web应用中,可让用户在浏览器内体验智能化功能。TensorFlow.js作为在客户端浏览器中运行的库,提供了强大支持。本文通过问答形式详细介绍如何使用TensorFlow.js将机器学习模型带入Web浏览器,并通过具体示例代码展示最佳实践。首先,需在HTML文件中引入TensorFlow.js库;接着,可通过加载预训练模型如MobileNet实现图像分类;然后,编写代码处理图像识别并显示结果;此外,还介绍了如何训练自定义模型及优化模型性能的方法,包括模型量化、剪枝和压缩等。
96 1
|
5月前
|
机器学习/深度学习 人工智能 PyTorch
AI智能体研发之路-模型篇(五):pytorch vs tensorflow框架DNN网络结构源码级对比
AI智能体研发之路-模型篇(五):pytorch vs tensorflow框架DNN网络结构源码级对比
93 1
|
5月前
|
API UED 开发者
如何在Uno Platform中轻松实现流畅动画效果——从基础到优化,全方位打造用户友好的动态交互体验!
【8月更文挑战第31天】在开发跨平台应用时,确保用户界面流畅且具吸引力至关重要。Uno Platform 作为多端统一的开发框架,不仅支持跨系统应用开发,还能通过优化实现流畅动画,增强用户体验。本文探讨了Uno Platform中实现流畅动画的多个方面,包括动画基础、性能优化、实践技巧及问题排查,帮助开发者掌握具体优化策略,提升应用质量与用户满意度。通过合理利用故事板、减少布局复杂性、使用硬件加速等技术,结合异步方法与预设缓存技巧,开发者能够创建美观且流畅的动画效果。
100 0

热门文章

最新文章