深度学习小技巧(一):如何保存和恢复TensorFlow训练的模型

简介: 深度学习小技巧掌握:作者通过一个简单的例子详细介绍了如何将训练过程中的深度学习模型保存,然后如何加载。有了这个小技巧,再也不用担心在训练模型中出错了。

更多深度文章,请关注云计算频道:https://yq.aliyun.com/cloud


深度学习小技巧(二):如何保存和恢复scikit-learn训练的模型

如果深层神经网络模型的复杂度非常高的话,那么训练它可能需要相当长的一段时间,当然这也取决于你拥有的数据量,运行模型的硬件等等。在大多数情况下,你需要通过保存文件来保障你试验的稳定性,防止如果中断(或一个错误),你能够继续从没有错误的地方开始。

更重要的是,对于任何深度学习的框架,像TensorFlow,在成功的训练之后,你需要重新使用模型的学习参数来完成对新数据的预测。

在这篇文章中,我们来看一下如何保存和恢复TensorFlow模型,我们在此介绍一些最有用的方法,并提供一些例子。

1.首先我们将快速介绍TensorFlow模型

TensorFlow的主要功能是通过张量来传递其基本数据结构类似于NumPy中的多维数组,而图表则表示数据计算。它是一个符号库,这意味着定义图形和张量将仅创建一个模型,而获取张量的具体值和操作将在会话(session)中执行,会话(session)一种在图中执行建模操作的机制。会话关闭时,张量的任何具体值都会丢失,这也是运行会话后将模型保存到文件的另一个原因。

通过示例可以帮助我们更容易理解,所以让我们为二维数据的线性回归创建一个简单的TensorFlow模型。

首先,我们将导入我们的库:


import tensorflow as tf  
import numpy as np  
import matplotlib.pyplot as plt  
%matplotlib inline

下一步是创建模型。我们将生成一个模型,它将以以下的形式估算二次函数的水平和垂直位移:

y = (x - h) ^ 2 + v  

其中h是水平和v是垂直的变化。

以下是如何生成模型的过程(有关详细信息,请参阅代码中的注释):


# Clear the current graph in each run, to avoid variable duplication
tf.reset_default_graph()
# Create placeholders for the x and y points
X = tf.placeholder("float")  
Y = tf.placeholder("float")
# Initialize the two parameters that need to be learned
h_est = tf.Variable(0.0, name='hor_estimate')  
v_est = tf.Variable(0.0, name='ver_estimate')
# y_est holds the estimated values on y-axis
y_est = tf.square(X - h_est) + v_est
# Define a cost function as the squared distance between Y and y_est
cost = (tf.pow(Y - y_est, 2))
# The training operation for minimizing the cost function. The
# learning rate is 0.001
trainop = tf.train.GradientDescentOptimizer(0.001).minimize(cost) 

在创建模型的过程中,我们需要有一个在会话中运行的模型,并且传递一些真实的数据。我们生成一些二次数据(Quadratic data),并给他们添加噪声。

# Use some values for the horizontal and vertical shift
h = 1  
v = -2
# Generate training data with noise
x_train = np.linspace(-2,4,201)  
noise = np.random.randn(*x_train.shape) * 0.4  
y_train = (x_train - h) ** 2 + v + noise
# Visualize the data 
plt.rcParams['figure.figsize'] = (10, 6)  
plt.scatter(x_train, y_train)  
plt.xlabel('x_train')  
plt.ylabel('y_train')  

7623c197c4b1b21b779b85dcfa10591975082dbf

2.The Saver class

Saver类是TensorFlow库提供的类,它是保存图形结构和变量的首选方法。

2.1保存模型

在以下几行代码中,我们定义一个Saver对象,并在train_graph()函数中,经过100次迭代的方法最小化成本函数。然后,在每次迭代中以及优化完成后,将模型保存到磁盘。每个保存在磁盘上创建二进制文件被称为“检查点”。


# Create a Saver object
saver = tf.train.Saver()

init = tf.global_variables_initializer()

# Run a session. Go through 100 iterations to minimize the cost
def train_graph():  
    with tf.Session() as sess:
        sess.run(init)
        for i in range(100):
            for (x, y) in zip(x_train, y_train):

                # Feed actual data to the train operation
                sess.run(trainop, feed_dict={X: x, Y: y})

            # Create a checkpoint in every iteration
            saver.save(sess, 'model_iter', global_step=i)

        # Save the final model
        saver.save(sess, 'model_final')
        h_ = sess.run(h_est)
        v_ = sess.run(v_est)
    return h_, v_

现在让我们用上述功能训练模型,并打印出训练的参数。


result = train_graph()  
print("h_est = %.2f, v_est = %.2f" % result)  

$ python tf_save.py
h_est = 1.01, v_est = -1.96  

Okay,参数是非常准确的。如果我们检查我们的文件系统,最后4次迭代中保存有文件以及最终的模型。

保存模型时,你会注意到需要4种类型的文件才能保存:

“.meta”文件:包含图形结构。

“.data”文件:包含变量的值。

“.index”文件:标识检查点。

“checkpoint”文件:具有最近检查点列表的协议缓冲区。

b6f873c4df16be6443a7374a24f91f5344f95a97

图1:检查点文件保存到磁盘

调用tf.train.Saver()方法,如上所示,将所有变量保存到一个文件。通过将它们作为参数,表情通过列表或dict传递来保存变量的子集,例如:tf.train.Saver({'hor_estimate': h_est})

Saver构造函数的一些其他有用的参数,也可以控制整个过程,它们是:

1.max_to_keep:最多保留的检查点数。

2.keep_checkpoint_every_n_hours:保存检查点的时间间隔。

如果你想要了解更多信息,请查看官方文档Saver类,它提供了其它有用的信息,你可以探索查看。

3.Restoring Models

恢复TensorFlow模型时要做的第一件事就是将图形结构从“.meta”文件加载到当前图形中。

tf.reset_default_graph()  
imported_meta = tf.train.import_meta_graph("model_final.meta")  

也可以使用以下命令探索当前图形tf.get_default_graph()。接着第二步是加载变量的值。提醒:值仅存在于会话(session)中。

with tf.Session() as sess:  
    imported_meta.restore(sess, tf.train.latest_checkpoint('./'))
    h_est2 = sess.run('hor_estimate:0')
    v_est2 = sess.run('ver_estimate:0')
    print("h_est: %.2f, v_est: %.2f" % (h_est2, v_est2))


$ python tf_restore.py
INFO:tensorflow:Restoring parameters from ./model_final  
h_est: 1.01, v_est: -1.96  

如前面所提到的,这种方法只保存图形结构和变量,这意味着通过占位符“X”和“Y”输入的训练数据不会被保存。

无论如何,在这个例子中,我们将使用我们定义的训练数据tf,并且可视化模型拟合。

plt.scatter(x_train, y_train, label='train data')  
plt.plot(x_train, (x_train - h_est2) ** 2 + v_est2, color='red', label='model')  
plt.xlabel('x_train')  
plt.ylabel('y_train')  
plt.legend() 

097e9c2626fe09aeb71bc947711b1f29ebc8ae4c

Saver这个类允许使用一个简单的方法来保存和恢复你的TensorFlow模型(图形和变量)到/从文件,并保留你工作中的多个检查点,这可能是有用的,它可以帮助你的模型在训练过程中进行微调。

4.SavedModel格式(Format)

在TensorFlow中保存和恢复模型的一种新方法是使用SavedModel,Builder和loader功能。这个方法实际上是Saver提供的更高级别的序列化,它更适合于商业目的。

虽然这种SavedModel方法似乎不被开发人员完全接受,但它的创作者指出:它显然是未来。与Saver主要关注变量的类相比,SavedModel尝试将一些有用的功能包含在一个包中,例如Signatures:允许保存具有一组输入和输出的图形,Assets:包含初始化中使用的外部文件。

4.1使用SavedModel Builder保存模型

接下来我们尝试使用SavedModelBuilder类完成模型的保存。在我们的示例中,我们不使用任何符号,但也足以说明该过程。

tf.reset_default_graph()
# Re-initialize our two variables
h_est = tf.Variable(h_est2, name='hor_estimate2')  
v_est = tf.Variable(v_est2, name='ver_estimate2')

# Create a builder
builder = tf.saved_model.builder.SavedModelBuilder('./SavedModel/')

# Add graph and variables to builder and save
with tf.Session() as sess:  
    sess.run(h_est.initializer)
    sess.run(v_est.initializer)
    builder.add_meta_graph_and_variables(sess,
                                       [tf.saved_model.tag_constants.TRAINING],
                                       signature_def_map=None,
                                       assets_collection=None)
builder.save()  

$ python tf_saved_model_builder.py
INFO:tensorflow:No assets to save.  
INFO:tensorflow:No assets to write.  
INFO:tensorflow:SavedModel written to: b'./SavedModel/saved_model.pb' 

运行此代码时,你会注意到我们的模型已保存到位于“./SavedModel/saved_model.pb”的文件中。

4.2使用SavedModel Loader程序恢复模型

模型恢复使用tf.saved_model.loader并且可以恢复会话范围中保存的变量,符号。

在下面的例子中,我们将加载模型,并打印出我们的两个系数(h_estv_est)的数值。数值如预期的那样,我们的模型已经被成功地恢复了。

with tf.Session() as sess:  
    tf.saved_model.loader.load(sess, [tf.saved_model.tag_constants.TRAINING], './SavedModel/')
    h_est = sess.run('hor_estimate2:0')
    v_est = sess.run('ver_estimate2:0')
    print("h_est: %.2f, v_est: %.2f" % (h_est, v_est))

$ python tf_saved_model_loader.py
INFO:tensorflow:Restoring parameters from b'./SavedModel/variables/variables'  
h_est: 1.01, v_est: -1.96  

5.结论

如果你知道你的深度学习网络的训练可能会花费很长时间,保存和恢复TensorFlow模型是非常有用的功能。该主题太广泛,无法在一篇博客文章中详细介绍。不管怎样,在这篇文章中我们介绍了两个工具:SaverSavedModel builder/loader,并创建一个文件结构,使用简单的线性回归来说明实例。希望这些能够帮助到你训练出更好的神经网络模型。

作者信息

1d0a417647e2b22e534a83c43b1b726531e8f4b5

作者:Mihajlo Pavloski数据科学与机器学习的爱好者,博士生。

本文由阿里云云社区组织翻译。

文章原标题《TensorFlow : Save and Restore Models

作者:Mihajlo Pavloski 译者:虎说八道,审阅:

文章为简译,更为详细的内容,请查看原文





相关文章
|
5月前
|
机器学习/深度学习 算法 定位技术
Baumer工业相机堡盟工业相机如何通过YoloV8深度学习模型实现裂缝的检测识别(C#代码UI界面版)
本项目基于YOLOv8模型与C#界面,结合Baumer工业相机,实现裂缝的高效检测识别。支持图像、视频及摄像头输入,具备高精度与实时性,适用于桥梁、路面、隧道等多种工业场景。
619 27
|
4月前
|
机器学习/深度学习 数据可视化 算法
深度学习模型结构复杂、参数众多,如何更直观地深入理解你的模型?
深度学习模型虽应用广泛,但其“黑箱”特性导致可解释性不足,尤其在金融、医疗等敏感领域,模型决策逻辑的透明性至关重要。本文聚焦深度学习可解释性中的可视化分析,介绍模型结构、特征、参数及输入激活的可视化方法,帮助理解模型行为、提升透明度,并推动其在关键领域的安全应用。
432 0
|
3月前
|
机器学习/深度学习 存储 PyTorch
Neural ODE原理与PyTorch实现:深度学习模型的自适应深度调节
Neural ODE将神经网络与微分方程结合,用连续思维建模数据演化,突破传统离散层的限制,实现自适应深度与高效连续学习。
203 3
Neural ODE原理与PyTorch实现:深度学习模型的自适应深度调节
|
2月前
|
机器学习/深度学习 数据采集 人工智能
深度学习实战指南:从神经网络基础到模型优化的完整攻略
🌟 蒋星熠Jaxonic,AI探索者。深耕深度学习,从神经网络到Transformer,用代码践行智能革命。分享实战经验,助你构建CV、NLP模型,共赴二进制星辰大海。
|
5月前
|
机器学习/深度学习 人工智能 PyTorch
AI 基础知识从 0.2 到 0.3——构建你的第一个深度学习模型
本文以 MNIST 手写数字识别为切入点,介绍了深度学习的基本原理与实现流程,帮助读者建立起对神经网络建模过程的系统性理解。
645 15
AI 基础知识从 0.2 到 0.3——构建你的第一个深度学习模型
|
3月前
|
机器学习/深度学习 数据采集 传感器
【WOA-CNN-LSTM】基于鲸鱼算法优化深度学习预测模型的超参数研究(Matlab代码实现)
【WOA-CNN-LSTM】基于鲸鱼算法优化深度学习预测模型的超参数研究(Matlab代码实现)
246 0
|
5月前
|
机器学习/深度学习 人工智能 自然语言处理
AI 基础知识从 0.3 到 0.4——如何选对深度学习模型?
本系列文章从机器学习基础出发,逐步深入至深度学习与Transformer模型,探讨AI关键技术原理及应用。内容涵盖模型架构解析、典型模型对比、预训练与微调策略,并结合Hugging Face平台进行实战演示,适合初学者与开发者系统学习AI核心知识。
493 15
|
机器学习/深度学习 人工智能 算法
猫狗宠物识别系统Python+TensorFlow+人工智能+深度学习+卷积网络算法
宠物识别系统使用Python和TensorFlow搭建卷积神经网络,基于37种常见猫狗数据集训练高精度模型,并保存为h5格式。通过Django框架搭建Web平台,用户上传宠物图片即可识别其名称,提供便捷的宠物识别服务。
1021 55
|
机器学习/深度学习 数据采集 数据可视化
TensorFlow,一款由谷歌开发的开源深度学习框架,详细讲解了使用 TensorFlow 构建深度学习模型的步骤
本文介绍了 TensorFlow,一款由谷歌开发的开源深度学习框架,详细讲解了使用 TensorFlow 构建深度学习模型的步骤,包括数据准备、模型定义、损失函数与优化器选择、模型训练与评估、模型保存与部署,并展示了构建全连接神经网络的具体示例。此外,还探讨了 TensorFlow 的高级特性,如自动微分、模型可视化和分布式训练,以及其在未来的发展前景。
1033 5
|
机器学习/深度学习 人工智能 TensorFlow
基于TensorFlow的深度学习模型训练与优化实战
基于TensorFlow的深度学习模型训练与优化实战
571 3

热门文章

最新文章