Tensorflow将模型导出为一个文件及接口设置

本文涉及的产品
交互式建模 PAI-DSW,5000CU*H 3个月
简介: Tensorflow将模型导出为一个文件及接口设置

Tensorflow将模型导出为一个文件及接口设置


最近看到一个巨牛的人工智能教程,分享一下给大家。教程不仅是零基础,通俗易懂,而且非常风趣幽默,像看小说一样!觉得太牛了,所以分享给大家。平时碎片时间可以当小说看,【点这里可以去膜拜一下大神的“小说”】。

在上一篇文章中《Tensorflow加载预训练模型和保存模型》,我们学习到如何使用预训练的模型。但注意到,在上一篇文章中使用预训练模型,必须至少的要4个文件:

checkpoint
MyModel.meta
MyModel.data-00000-of-00001
MyModel.index

这很不便于我们的使用。有没有办法导出为一个pb文件,然后直接使用呢?答案是肯定的。在文章《Tensorflow加载预训练模型和保存模型》中提到,meta文件保存图结构,weights等参数保存在data文件中。也就是说,图和参数数据时分开保存的。说的更直白一点,就是meta文件中没有weights等数据。但是,值得注意的是,**meta文件会保存常量。**我们只需将data文件中的参数转为meta文件中的常量即可!

1 模型导出为一个文件

1.1 有代码并且从头开始训练

Tensorflow提供了工具函数tf.graph_util.convert_variables_to_constants()用于将变量转为常量。看看官网的描述:

if you have a trained graph containing Variable ops, it can be convenient to convert them all to Const ops holding the same values. This makes it possible to describe the network fully with a single GraphDef file, and allows the removal of a lot of ops related to loading and saving the variables.

我们继续通过一个简单例子开始:

import tensorflow as tf
w1 = tf.Variable(20.0, name="w1")
w2 = tf.Variable(30.0, name="w2")
b1= tf.Variable(2.0,name="bias")
w3 = tf.add(w1,w2)
#记住要定义name,后面需要用到
out = tf.multiply(w3,b1,name="out")
# 转换Variable为constant,并将网络写入到文件
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    # 这里需要填入输出tensor的名字
    graph = tf.graph_util.convert_variables_to_constants(sess, sess.graph_def, ["out"])
    tf.train.write_graph(graph, '.', './checkpoint_dir/graph.pb', as_text=False)

执行可以看到如下日志:

Converted 3 variables to const ops.

可以看到通过tf.graph_util.convert_variables_to_constants()函数将变量转为了常量,并存储在graph.pb文件中,接下来看看如何使用这个模型。

import tensorflow as tf
with tf.Session() as sess:
    with open('./checkpoint_dir/graph.pb', 'rb') as graph:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(graph.read())
        output = tf.import_graph_def(graph_def, return_elements=['out:0'])
        print(sess.run(output))

运行结果如下:

[100.0]

回到tf.graph_util.convert_variables_to_constants()函数,可以看到,需要传入Session对象和图,这都可以理解。看看第三个参数["out"],它是指定这个模型的输出Tensor。

##1.2 有代码和模型,但是不想重新训练模型

有模型源码时,在导出模型时就可以通过tf.graph_util.convert_variables_to_constants()函数来将变量转为常量保存到图文件中。但是很多时候,我们拿到的是别人的checkpoint文件,即meta、index、data等文件。这种情况下,需要将data文件里面变量转为常量保存到meta文件中。思路也很简单,先将checkpoint文件加载,再重新保存一次即可。

假设训练和保存模型代码如下:

import tensorflow as tf
w1 = tf.Variable(20.0, name="w1")
w2 = tf.Variable(30.0, name="w2")
b1= tf.Variable(2.0,name="bias")
w3 = tf.add(w1,w2)
#记住要定义name,后面需要用到
out = tf.multiply(w3,b1,name="out")
# 转换Variable为constant,并将网络写入到文件
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    saver = tf.train.Saver()
    # 这里需要填入输出tensor的名字
    saver.save(sess, './checkpoint_dir/MyModel', global_step=1000)

此时,模型文件如下:

checkpoint
MyModel-1000.data-00000-of-00001
MyModel-1000.index
MyModel-1000.meta

如果我们只有以上4个模型文件,但是可以看到训练源码。那么,将这4个文件导出为一个pb文件方法如下:

import tensorflow as tf
with tf.Session() as sess:
    #初始化变量
    sess.run(tf.global_variables_initializer())
    #获取最新的checkpoint,其实就是解析了checkpoint文件
    latest_ckpt = tf.train.latest_checkpoint("./checkpoint_dir")
    #加载图
    restore_saver = tf.train.import_meta_graph('./checkpoint_dir/MyModel-1000.meta')
    #恢复图,即将weights等参数加入图对应位置中
    restore_saver.restore(sess, latest_ckpt)
    #将图中的变量转为常量
    output_graph_def = tf.graph_util.convert_variables_to_constants(
        sess, sess.graph_def , ["out"])
    #将新的图保存到"/pretrained/graph.pb"文件中
    tf.train.write_graph(output_graph_def, 'pretrained', "graph.pb", as_text=False)

执行后,会有如下日志:

Converted 3 variables to const ops.

接下来就是使用,使用方法跟前面一致:

import tensorflow as tf
with tf.Session() as sess:
    with open('./pretrained/graph.pb', 'rb') as graph:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(graph.read())
        output = tf.import_graph_def(graph_def, return_elements=['out:0'])
        print(sess.run(output))

打印信息如下:

[100.0]

2 模型接口设置

我们注意到,前面只是简单的获取一个输出接口,但是很明显,我们使用的时候,不可能只有一个输出,还需要有输入,接下来我们看看,如何设置输入和输出。同样我们分为有代码并且从头开始训练,和有代码和模型,但是不想重新训练模型两种情况。

2.1 有代码并且从头开始训练

相比1.1中的代码略作修改即可,第6行代码处做了修改:

import tensorflow as tf
w1 = tf.Variable(20.0, name="w1")
w2 = tf.Variable(30.0, name="w2")
#这里将b1改为placeholder,让用户输入,而不是写死
#b1= tf.Variable(2.0,name="bias")
b1= tf.placeholder(tf.float32, name='bias')
w3 = tf.add(w1,w2)
#记住要定义name,后面需要用到
out = tf.multiply(w3,b1,name="out")
# 转换Variable为constant,并将网络写入到文件
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    # 这里需要填入输出tensor的名字
    graph = tf.graph_util.convert_variables_to_constants(sess, sess.graph_def, ["out"])
    tf.train.write_graph(graph, '.', './checkpoint_dir/graph.pb', as_text=False)

日志如下:

Converted 2 variables to const ops.

接下来看看如何使用:

import tensorflow as tf
with tf.Session() as sess:
    with open('./checkpoint_dir/graph.pb', 'rb') as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
        output = tf.import_graph_def(graph_def, input_map={'bias:0':4.}, return_elements=['out:0'])
        print(sess.run(output))

打印信息如下:

[200.0]

也就是说,在设置输入时,首先将需要输入的数据作为placeholdler,然后在导入图tf.import_graph_def()时,通过参数input_map={}来指定输入。输出通过return_elements=[]直接引用tensor的name即可。

2.2 有代码和模型,但是不想重新训练模型

在有代码和模型,但是不想重新训练模型情况下,意味着我们不能直接修改导出模型的代码。但是我们可以通过graph.get_tensor_by_name()函数取得图中的某些中间结果,然后再加入一些逻辑。其实这种情况在上一篇文章已经讲了。可以参考上一篇文件解决,相比“有代码并且从头开始训练”情况局限比较大,大部分情况只能是获取模型的一些中间结果,但是也满足我们大多数情况使用了。

相关文章
|
4月前
|
机器学习/深度学习 算法 TensorFlow
文本分类识别Python+卷积神经网络算法+TensorFlow模型训练+Django可视化界面
文本分类识别Python+卷积神经网络算法+TensorFlow模型训练+Django可视化界面
71 0
文本分类识别Python+卷积神经网络算法+TensorFlow模型训练+Django可视化界面
|
4月前
|
机器学习/深度学习 监控 Python
tensorflow2.x多层感知机模型参数量和计算量的统计
tensorflow2.x多层感知机模型参数量和计算量的统计
|
4月前
|
机器学习/深度学习 搜索推荐 算法
推荐系统离线评估方法和评估指标,以及在推荐服务器内部实现A/B测试和解决A/B测试资源紧张的方法。还介绍了如何在TensorFlow中进行模型离线评估实践。
推荐系统离线评估方法和评估指标,以及在推荐服务器内部实现A/B测试和解决A/B测试资源紧张的方法。还介绍了如何在TensorFlow中进行模型离线评估实践。
205 0
|
22小时前
|
机器学习/深度学习 TensorFlow 算法框架/工具
关于Tensorflow!目标检测预训练模型的迁移学习
这篇文章主要介绍了使用Tensorflow进行目标检测的迁移学习过程。关于使用Tensorflow进行目标检测模型训练的实战教程,涵盖了从数据准备到模型应用的全过程,特别适合对此领域感兴趣的开发者参考。
12 3
关于Tensorflow!目标检测预训练模型的迁移学习
|
2天前
|
机器学习/深度学习 TensorFlow API
Python深度学习基于Tensorflow(3)Tensorflow 构建模型
Python深度学习基于Tensorflow(3)Tensorflow 构建模型
11 2
|
11天前
|
机器学习/深度学习 数据可视化 TensorFlow
【Python 机器学习专栏】使用 TensorFlow 构建深度学习模型
【4月更文挑战第30天】本文介绍了如何使用 TensorFlow 构建深度学习模型。TensorFlow 是谷歌的开源深度学习框架,具备强大计算能力和灵活编程接口。构建模型涉及数据准备、模型定义、选择损失函数和优化器、训练、评估及模型保存部署。文中以全连接神经网络为例,展示了从数据预处理到模型训练和评估的完整流程。此外,还提到了 TensorFlow 的自动微分、模型可视化和分布式训练等高级特性。通过本文,读者可掌握 TensorFlow 基本用法,为构建高效深度学习模型打下基础。
|
13天前
|
机器学习/深度学习 算法 TensorFlow
TensorFlow 2keras开发深度学习模型实例:多层感知器(MLP),卷积神经网络(CNN)和递归神经网络(RNN)
TensorFlow 2keras开发深度学习模型实例:多层感知器(MLP),卷积神经网络(CNN)和递归神经网络(RNN)
|
15天前
|
机器学习/深度学习 TensorFlow API
Python安装TensorFlow 2、tf.keras和深度学习模型的定义
Python安装TensorFlow 2、tf.keras和深度学习模型的定义
|
22天前
|
机器学习/深度学习 大数据 TensorFlow
使用TensorFlow实现Python简版神经网络模型
使用TensorFlow实现Python简版神经网络模型
|
25天前
|
机器学习/深度学习 TensorFlow 算法框架/工具
TensorFlow的保存与加载模型
【4月更文挑战第17天】本文介绍了TensorFlow中模型的保存与加载。保存模型能节省训练时间,便于部署和复用。在TensorFlow中,可使用`save_model_to_hdf5`保存模型结构,`save_weights`保存权重,或转换为SavedModel格式。加载时,通过`load_model`恢复结构,`load_weights`加载权重。注意模型结构一致性、环境依赖及自定义层的兼容性问题。正确保存和加载能有效利用模型资源,提升效率和准确性。

热门文章

最新文章