Tensorflow将模型导出为一个文件及接口设置

本文涉及的产品
模型在线服务 PAI-EAS,A10/V100等 500元 1个月
模型训练 PAI-DLC,5000CU*H 3个月
交互式建模 PAI-DSW,每月250计算时 3个月
简介: Tensorflow将模型导出为一个文件及接口设置

Tensorflow将模型导出为一个文件及接口设置


最近看到一个巨牛的人工智能教程,分享一下给大家。教程不仅是零基础,通俗易懂,而且非常风趣幽默,像看小说一样!觉得太牛了,所以分享给大家。平时碎片时间可以当小说看,【点这里可以去膜拜一下大神的“小说”】。

在上一篇文章中《Tensorflow加载预训练模型和保存模型》,我们学习到如何使用预训练的模型。但注意到,在上一篇文章中使用预训练模型,必须至少的要4个文件:

checkpoint
MyModel.meta
MyModel.data-00000-of-00001
MyModel.index

这很不便于我们的使用。有没有办法导出为一个pb文件,然后直接使用呢?答案是肯定的。在文章《Tensorflow加载预训练模型和保存模型》中提到,meta文件保存图结构,weights等参数保存在data文件中。也就是说,图和参数数据时分开保存的。说的更直白一点,就是meta文件中没有weights等数据。但是,值得注意的是,**meta文件会保存常量。**我们只需将data文件中的参数转为meta文件中的常量即可!

1 模型导出为一个文件

1.1 有代码并且从头开始训练

Tensorflow提供了工具函数tf.graph_util.convert_variables_to_constants()用于将变量转为常量。看看官网的描述:

if you have a trained graph containing Variable ops, it can be convenient to convert them all to Const ops holding the same values. This makes it possible to describe the network fully with a single GraphDef file, and allows the removal of a lot of ops related to loading and saving the variables.

我们继续通过一个简单例子开始:

import tensorflow as tf
w1 = tf.Variable(20.0, name="w1")
w2 = tf.Variable(30.0, name="w2")
b1= tf.Variable(2.0,name="bias")
w3 = tf.add(w1,w2)
#记住要定义name,后面需要用到
out = tf.multiply(w3,b1,name="out")
# 转换Variable为constant,并将网络写入到文件
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    # 这里需要填入输出tensor的名字
    graph = tf.graph_util.convert_variables_to_constants(sess, sess.graph_def, ["out"])
    tf.train.write_graph(graph, '.', './checkpoint_dir/graph.pb', as_text=False)

执行可以看到如下日志:

Converted 3 variables to const ops.

可以看到通过tf.graph_util.convert_variables_to_constants()函数将变量转为了常量,并存储在graph.pb文件中,接下来看看如何使用这个模型。

import tensorflow as tf
with tf.Session() as sess:
    with open('./checkpoint_dir/graph.pb', 'rb') as graph:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(graph.read())
        output = tf.import_graph_def(graph_def, return_elements=['out:0'])
        print(sess.run(output))

运行结果如下:

[100.0]

回到tf.graph_util.convert_variables_to_constants()函数,可以看到,需要传入Session对象和图,这都可以理解。看看第三个参数["out"],它是指定这个模型的输出Tensor。

##1.2 有代码和模型,但是不想重新训练模型

有模型源码时,在导出模型时就可以通过tf.graph_util.convert_variables_to_constants()函数来将变量转为常量保存到图文件中。但是很多时候,我们拿到的是别人的checkpoint文件,即meta、index、data等文件。这种情况下,需要将data文件里面变量转为常量保存到meta文件中。思路也很简单,先将checkpoint文件加载,再重新保存一次即可。

假设训练和保存模型代码如下:

import tensorflow as tf
w1 = tf.Variable(20.0, name="w1")
w2 = tf.Variable(30.0, name="w2")
b1= tf.Variable(2.0,name="bias")
w3 = tf.add(w1,w2)
#记住要定义name,后面需要用到
out = tf.multiply(w3,b1,name="out")
# 转换Variable为constant,并将网络写入到文件
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    saver = tf.train.Saver()
    # 这里需要填入输出tensor的名字
    saver.save(sess, './checkpoint_dir/MyModel', global_step=1000)

此时,模型文件如下:

checkpoint
MyModel-1000.data-00000-of-00001
MyModel-1000.index
MyModel-1000.meta

如果我们只有以上4个模型文件,但是可以看到训练源码。那么,将这4个文件导出为一个pb文件方法如下:

import tensorflow as tf
with tf.Session() as sess:
    #初始化变量
    sess.run(tf.global_variables_initializer())
    #获取最新的checkpoint,其实就是解析了checkpoint文件
    latest_ckpt = tf.train.latest_checkpoint("./checkpoint_dir")
    #加载图
    restore_saver = tf.train.import_meta_graph('./checkpoint_dir/MyModel-1000.meta')
    #恢复图,即将weights等参数加入图对应位置中
    restore_saver.restore(sess, latest_ckpt)
    #将图中的变量转为常量
    output_graph_def = tf.graph_util.convert_variables_to_constants(
        sess, sess.graph_def , ["out"])
    #将新的图保存到"/pretrained/graph.pb"文件中
    tf.train.write_graph(output_graph_def, 'pretrained', "graph.pb", as_text=False)

执行后,会有如下日志:

Converted 3 variables to const ops.

接下来就是使用,使用方法跟前面一致:

import tensorflow as tf
with tf.Session() as sess:
    with open('./pretrained/graph.pb', 'rb') as graph:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(graph.read())
        output = tf.import_graph_def(graph_def, return_elements=['out:0'])
        print(sess.run(output))

打印信息如下:

[100.0]

2 模型接口设置

我们注意到,前面只是简单的获取一个输出接口,但是很明显,我们使用的时候,不可能只有一个输出,还需要有输入,接下来我们看看,如何设置输入和输出。同样我们分为有代码并且从头开始训练,和有代码和模型,但是不想重新训练模型两种情况。

2.1 有代码并且从头开始训练

相比1.1中的代码略作修改即可,第6行代码处做了修改:

import tensorflow as tf
w1 = tf.Variable(20.0, name="w1")
w2 = tf.Variable(30.0, name="w2")
#这里将b1改为placeholder,让用户输入,而不是写死
#b1= tf.Variable(2.0,name="bias")
b1= tf.placeholder(tf.float32, name='bias')
w3 = tf.add(w1,w2)
#记住要定义name,后面需要用到
out = tf.multiply(w3,b1,name="out")
# 转换Variable为constant,并将网络写入到文件
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    # 这里需要填入输出tensor的名字
    graph = tf.graph_util.convert_variables_to_constants(sess, sess.graph_def, ["out"])
    tf.train.write_graph(graph, '.', './checkpoint_dir/graph.pb', as_text=False)

日志如下:

Converted 2 variables to const ops.

接下来看看如何使用:

import tensorflow as tf
with tf.Session() as sess:
    with open('./checkpoint_dir/graph.pb', 'rb') as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
        output = tf.import_graph_def(graph_def, input_map={'bias:0':4.}, return_elements=['out:0'])
        print(sess.run(output))

打印信息如下:

[200.0]

也就是说,在设置输入时,首先将需要输入的数据作为placeholdler,然后在导入图tf.import_graph_def()时,通过参数input_map={}来指定输入。输出通过return_elements=[]直接引用tensor的name即可。

2.2 有代码和模型,但是不想重新训练模型

在有代码和模型,但是不想重新训练模型情况下,意味着我们不能直接修改导出模型的代码。但是我们可以通过graph.get_tensor_by_name()函数取得图中的某些中间结果,然后再加入一些逻辑。其实这种情况在上一篇文章已经讲了。可以参考上一篇文件解决,相比“有代码并且从头开始训练”情况局限比较大,大部分情况只能是获取模型的一些中间结果,但是也满足我们大多数情况使用了。

相关文章
|
1月前
|
机器学习/深度学习 TensorFlow 算法框架/工具
深度学习之格式转换笔记(三):keras(.hdf5)模型转TensorFlow(.pb) 转TensorRT(.uff)格式
将Keras训练好的.hdf5模型转换为TensorFlow的.pb模型,然后再转换为TensorRT支持的.uff格式,并提供了转换代码和测试步骤。
83 3
深度学习之格式转换笔记(三):keras(.hdf5)模型转TensorFlow(.pb) 转TensorRT(.uff)格式
|
5天前
|
机器学习/深度学习 人工智能 算法
【手写数字识别】Python+深度学习+机器学习+人工智能+TensorFlow+算法模型
手写数字识别系统,使用Python作为主要开发语言,基于深度学习TensorFlow框架,搭建卷积神经网络算法。并通过对数据集进行训练,最后得到一个识别精度较高的模型。并基于Flask框架,开发网页端操作平台,实现用户上传一张图片识别其名称。
21 0
【手写数字识别】Python+深度学习+机器学习+人工智能+TensorFlow+算法模型
|
5天前
|
机器学习/深度学习 人工智能 算法
基于深度学习的【蔬菜识别】系统实现~Python+人工智能+TensorFlow+算法模型
蔬菜识别系统,本系统使用Python作为主要编程语言,通过收集了8种常见的蔬菜图像数据集('土豆', '大白菜', '大葱', '莲藕', '菠菜', '西红柿', '韭菜', '黄瓜'),然后基于TensorFlow搭建卷积神经网络算法模型,通过多轮迭代训练最后得到一个识别精度较高的模型文件。在使用Django开发web网页端操作界面,实现用户上传一张蔬菜图片识别其名称。
25 0
基于深度学习的【蔬菜识别】系统实现~Python+人工智能+TensorFlow+算法模型
|
21天前
|
机器学习/深度学习 人工智能 算法
【车辆车型识别】Python+卷积神经网络算法+深度学习+人工智能+TensorFlow+算法模型
车辆车型识别,使用Python作为主要编程语言,通过收集多种车辆车型图像数据集,然后基于TensorFlow搭建卷积网络算法模型,并对数据集进行训练,最后得到一个识别精度较高的模型文件。再基于Django搭建web网页端操作界面,实现用户上传一张车辆图片识别其类型。
65 0
【车辆车型识别】Python+卷积神经网络算法+深度学习+人工智能+TensorFlow+算法模型
|
2月前
|
机器学习/深度学习 人工智能 算法
鸟类识别系统Python+卷积神经网络算法+深度学习+人工智能+TensorFlow+ResNet50算法模型+图像识别
鸟类识别系统。本系统采用Python作为主要开发语言,通过使用加利福利亚大学开源的200种鸟类图像作为数据集。使用TensorFlow搭建ResNet50卷积神经网络算法模型,然后进行模型的迭代训练,得到一个识别精度较高的模型,然后在保存为本地的H5格式文件。在使用Django开发Web网页端操作界面,实现用户上传一张鸟类图像,识别其名称。
108 12
鸟类识别系统Python+卷积神经网络算法+深度学习+人工智能+TensorFlow+ResNet50算法模型+图像识别
|
1月前
|
机器学习/深度学习 移动开发 TensorFlow
深度学习之格式转换笔记(四):Keras(.h5)模型转化为TensorFlow(.pb)模型
本文介绍了如何使用Python脚本将Keras模型转换为TensorFlow的.pb格式模型,包括加载模型、重命名输出节点和量化等步骤,以便在TensorFlow中进行部署和推理。
76 0
|
3月前
|
API UED 开发者
如何在Uno Platform中轻松实现流畅动画效果——从基础到优化,全方位打造用户友好的动态交互体验!
【8月更文挑战第31天】在开发跨平台应用时,确保用户界面流畅且具吸引力至关重要。Uno Platform 作为多端统一的开发框架,不仅支持跨系统应用开发,还能通过优化实现流畅动画,增强用户体验。本文探讨了Uno Platform中实现流畅动画的多个方面,包括动画基础、性能优化、实践技巧及问题排查,帮助开发者掌握具体优化策略,提升应用质量与用户满意度。通过合理利用故事板、减少布局复杂性、使用硬件加速等技术,结合异步方法与预设缓存技巧,开发者能够创建美观且流畅的动画效果。
79 0
|
3月前
|
C# 开发者 前端开发
揭秘混合开发新趋势:Uno Platform携手Blazor,教你一步到位实现跨平台应用,代码复用不再是梦!
【8月更文挑战第31天】随着前端技术的发展,混合开发日益受到开发者青睐。本文详述了如何结合.NET生态下的两大框架——Uno Platform与Blazor,进行高效混合开发。Uno Platform基于WebAssembly和WebGL技术,支持跨平台应用构建;Blazor则让C#成为可能的前端开发语言,实现了客户端与服务器端逻辑共享。二者结合不仅提升了代码复用率与跨平台能力,还简化了项目维护并增强了Web应用性能。文中提供了从环境搭建到示例代码的具体步骤,并展示了如何创建一个简单的计数器应用,帮助读者快速上手混合开发。
83 0
|
3月前
|
开发者 算法 虚拟化
惊爆!Uno Platform 调试与性能分析终极攻略,从工具运用到代码优化,带你攻克开发难题成就完美应用
【8月更文挑战第31天】在 Uno Platform 中,调试可通过 Visual Studio 设置断点和逐步执行代码实现,同时浏览器开发者工具有助于 Web 版本调试。性能分析则利用 Visual Studio 的性能分析器检查 CPU 和内存使用情况,还可通过记录时间戳进行简单分析。优化性能涉及代码逻辑优化、资源管理和用户界面简化,综合利用平台提供的工具和技术,确保应用高效稳定运行。
83 0
|
3月前
|
前端开发 开发者 设计模式
揭秘Uno Platform状态管理之道:INotifyPropertyChanged、依赖注入、MVVM大对决,帮你找到最佳策略!
【8月更文挑战第31天】本文对比分析了 Uno Platform 中的关键状态管理策略,包括内置的 INotifyPropertyChanged、依赖注入及 MVVM 框架。INotifyPropertyChanged 方案简单易用,适合小型项目;依赖注入则更灵活,支持状态共享与持久化,适用于复杂场景;MVVM 框架通过分离视图、视图模型和模型,使状态管理更清晰,适合大型项目。开发者可根据项目需求和技术栈选择合适的状态管理方案,以实现高效管理。
43 0

热门文章

最新文章