TensorFlow -- 模型保存与读取

简介: 最近学习Google的深度学习框架TensorFlow,CNN模型训练什么的都是OK的,官方也有代码,文章请参照: 但是在实际使用的时候不可能每次预测都训练一遍模型,这样太浪费时间,所以需要我们在训练完成的时候保存模型,并且在需要预测的时候加载。

最近学习Google的深度学习框架TensorFlow,CNN模型训练什么的都是OK的,官方也有代码,中文详解请参照:
http://www.soaringroad.com/?p=115

但是在实际使用的时候不可能每次预测都训练一遍模型,这样太浪费时间,所以需要我们在训练完成的时候保存模型,并且在需要预测的时候加载。官方提供的例子和解释不够具体,让我踩了很多的坑,所以写个笔记分享一下,希望帮助大家跳过或者少踩这些坑。

模型保存:

①首先对于需要保存的变量进行定义,记得variable和placeholder保存用变量名的定义一定不能忘了

定义的形式大体上如下:

var_name_1= tf.Variable(........,name='var_name_1_store')
var_name_2= tf.argmax(var_name_1,name='var_name_2_store')
var_name_3=tf.placeholder(........,name='var_name_3_store')
var_name_4=tf.matmul(var_name_1,var_name_3,name='var_name_4_store')

②其次就是保存处理

需要利用 tf.train.Saver来保存模型,其中global_step不定义的情况下,默认为0

saver = tf.train.Saver()
saver.save(sess,'./data.chkp',global_step=XX)

模型加载:

①首先读取刚刚保存的meta文件,然后全局变量初始化,需要用到tf.train.import_meta_graph

saver = tf.train.import_meta_graph("./data.chkp.meta")
sess.run(tf.global_variables_initializer())

②其次加载我们需要的变量,并预测,这里用到var_name_3_store,这就是为什么前面placeholder定义的时候一定要定义name

 predict = tf.get_default_graph().get_tensor_by_name("var_name_4_store:0")
 predict.eval(feed_dist={'var_name_3_store':XXXXX})

 

相关文章
|
1月前
|
机器学习/深度学习 算法 TensorFlow
文本分类识别Python+卷积神经网络算法+TensorFlow模型训练+Django可视化界面
文本分类识别Python+卷积神经网络算法+TensorFlow模型训练+Django可视化界面
98 0
文本分类识别Python+卷积神经网络算法+TensorFlow模型训练+Django可视化界面
|
16天前
|
机器学习/深度学习 TensorFlow API
TensorFlow与Keras实战:构建深度学习模型
本文探讨了TensorFlow和其高级API Keras在深度学习中的应用。TensorFlow是Google开发的高性能开源框架,支持分布式计算,而Keras以其用户友好和模块化设计简化了神经网络构建。通过一个手写数字识别的实战案例,展示了如何使用Keras加载MNIST数据集、构建CNN模型、训练及评估模型,并进行预测。案例详述了数据预处理、模型构建、训练过程和预测新图像的步骤,为读者提供TensorFlow和Keras的基础实践指导。
150 59
|
1月前
|
机器学习/深度学习 TensorFlow 算法框架/工具
关于Tensorflow!目标检测预训练模型的迁移学习
这篇文章主要介绍了使用Tensorflow进行目标检测的迁移学习过程。关于使用Tensorflow进行目标检测模型训练的实战教程,涵盖了从数据准备到模型应用的全过程,特别适合对此领域感兴趣的开发者参考。
35 3
关于Tensorflow!目标检测预训练模型的迁移学习
|
24天前
|
机器学习/深度学习 算法 TensorFlow
【图像识别】谷物识别系统Python+人工智能深度学习+TensorFlow+卷积算法网络模型+图像识别
谷物识别系统,本系统使用Python作为主要编程语言,通过TensorFlow搭建ResNet50卷积神经算法网络模型,通过对11种谷物图片数据集('大米', '小米', '燕麦', '玉米渣', '红豆', '绿豆', '花生仁', '荞麦', '黄豆', '黑米', '黑豆')进行训练,得到一个进度较高的H5格式的模型文件。然后使用Django框架搭建了一个Web网页端可视化操作界面。实现用户上传一张图片识别其名称。
65 0
【图像识别】谷物识别系统Python+人工智能深度学习+TensorFlow+卷积算法网络模型+图像识别
|
1月前
|
机器学习/深度学习 人工智能 算法
食物识别系统Python+深度学习人工智能+TensorFlow+卷积神经网络算法模型
食物识别系统采用TensorFlow的ResNet50模型,训练了包含11类食物的数据集,生成高精度H5模型。系统整合Django框架,提供网页平台,用户可上传图片进行食物识别。效果图片展示成功识别各类食物。[查看演示视频、代码及安装指南](https://www.yuque.com/ziwu/yygu3z/yhd6a7vai4o9iuys?singleDoc#)。项目利用深度学习的卷积神经网络(CNN),其局部感受野和权重共享机制适于图像识别,广泛应用于医疗图像分析等领域。示例代码展示了一个使用TensorFlow训练的简单CNN模型,用于MNIST手写数字识别。
66 3
|
1月前
|
机器学习/深度学习 TensorFlow API
Python深度学习基于Tensorflow(3)Tensorflow 构建模型
Python深度学习基于Tensorflow(3)Tensorflow 构建模型
81 2
|
28天前
|
机器学习/深度学习 人工智能 算法
中草药识别系统Python+深度学习人工智能+TensorFlow+卷积神经网络算法模型
中草药识别系统Python+深度学习人工智能+TensorFlow+卷积神经网络算法模型
65 0
|
30天前
|
机器学习/深度学习 自然语言处理 TensorFlow
构建高效的机器学习模型:基于Python和TensorFlow的实践
构建高效的机器学习模型:基于Python和TensorFlow的实践
41 0
|
1月前
|
机器学习/深度学习 大数据 TensorFlow
使用TensorFlow实现Python简版神经网络模型
使用TensorFlow实现Python简版神经网络模型
|
1月前
|
机器学习/深度学习 监控 测试技术
TensorFlow的模型评估与验证
【4月更文挑战第17天】TensorFlow是深度学习中用于模型评估与验证的重要框架,提供多样工具支持这一过程。模型评估衡量模型在未知数据上的表现,帮助识别性能和优化方向。在TensorFlow中,使用验证集和测试集评估模型,选择如准确率、召回率等指标,并通过`tf.keras.metrics`模块更新和获取评估结果。模型验证则确保模型稳定性和泛化能力,常用方法包括交叉验证和留出验证。通过这些方法,开发者能有效提升模型质量和性能。